Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
......@@ -13,7 +13,7 @@ repos:
args: [--output-format, github, --fix]
- id: ruff-format
- repo: https://github.com/crate-ci/typos
rev: v1.38.1
rev: v1.43.5
hooks:
- id: typos
args: [--force-exclude]
......@@ -24,12 +24,13 @@ repos:
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
types_or: [c++, cuda]
args: [--style=file, --verbose]
- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.45.0
- repo: https://github.com/DavidAnson/markdownlint-cli2
rev: v0.21.0
hooks:
- id: markdownlint
exclude: '.*\.inc\.md'
stages: [manual] # Only run in CI
- id: markdownlint-cli2
language_version: lts
args: [--fix]
exclude: ^CLAUDE\.md$
- repo: https://github.com/rhysd/actionlint
rev: v1.7.7
hooks:
......@@ -55,7 +56,7 @@ repos:
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
additional_dependencies: ["mypy[faster-cache]==1.19.1", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: python tools/pre_commit/mypy.py 1 "3.10"
......@@ -127,6 +128,13 @@ repos:
language: python
types: [python]
additional_dependencies: [regex]
# prevent use torch.cuda APIs
- id: check-torch-cuda-call
name: "Prevent new 'torch.cuda' APIs call"
entry: python tools/pre_commit/check_torch_cuda.py
language: python
types: [python]
additional_dependencies: [regex]
- id: validate-config
name: Validate configuration has default values and that each field has a docstring
entry: python tools/pre_commit/validate_config.py
......@@ -143,6 +151,11 @@ repos:
name: Check attention backend documentation is up to date
entry: python tools/pre_commit/generate_attention_backend_docs.py --check
language: python
- id: check-boolean-context-manager
name: Check for boolean ops in with-statements
entry: python tools/pre_commit/check_boolean_context_manager.py
language: python
types: [python]
# Keep `suggestion` last
- id: suggestion
name: Suggestion
......
......@@ -9,13 +9,15 @@ build:
python: "3.12"
jobs:
post_checkout:
- git fetch --unshallow || true
# - bash docs/maybe_skip_pr_build.sh
- git fetch origin main --unshallow --no-tags --filter=blob:none || true
pre_create_environment:
- pip install uv
create_environment:
- uv venv $READTHEDOCS_VIRTUALENV_PATH
install:
- uv pip install --python $READTHEDOCS_VIRTUALENV_PATH/bin/python --no-cache-dir -r requirements/docs.txt
mkdocs:
configuration: mkdocs.yaml
fail_on_warning: true
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: requirements/docs.txt
# Agent Instructions for vLLM
> These instructions apply to **all** AI-assisted contributions to `vllm-project/vllm`.
> Breaching these guidelines can result in automatic banning.
## 1. Contribution Policy (Mandatory)
### Duplicate-work checks
Before proposing a PR, run these checks:
```bash
gh issue view <issue_number> --repo vllm-project/vllm --comments
gh pr list --repo vllm-project/vllm --state open --search "<issue_number> in:body"
gh pr list --repo vllm-project/vllm --state open --search "<short area keywords>"
```
- If an open PR already addresses the same fix, do not open another.
- If your approach is materially different, explain the difference in the issue.
### No low-value busywork PRs
Do not open one-off PRs for tiny edits (single typo, isolated style change, one mutable default, etc.). Mechanical cleanups are acceptable only when bundled with substantive work.
### Accountability
- Pure code-agent PRs are **not allowed**. A human submitter must understand and defend the change end-to-end.
- The submitting human must review every changed line and run relevant tests.
- PR descriptions for AI-assisted work **must** include:
- Why this is not duplicating an existing PR.
- Test commands run and results.
- Clear statement that AI assistance was used.
### Fail-closed behavior
If work is duplicate/trivial busywork, **do not proceed**. Return a short explanation of what is missing.
---
## 2. Development Workflow
### Environment setup
```bash
# Install `uv` if you don't have it already:
curl -LsSf https://astral.sh/uv/install.sh | sh
# Always use `uv` for Python environment management:
uv venv --python 3.12
source .venv/bin/activate
# Always make sure `pre-commit` and its hooks are installed:
uv pip install -r requirements/lint.txt
pre-commit install
```
### Installing dependencies
```bash
# If you are only making Python changes:
VLLM_USE_PRECOMPILED=1 uv pip install -e .
# If you are also making C/C++ changes:
uv pip install -e .
```
### Running tests
Tests require extra dependencies.
All versions for test dependencies should be read from `requirements/test.txt`
```bash
# Install bare minimum test dependencies:
uv pip install pytest pytest-asyncio tblib
# Install additional test dependencies as needed, or install them all as follows:
uv pip install -r requirements/test.txt
# Run specific test from specific test file
pytest tests/path/to/test.py -v -s -k test_name
# Run all tests in directory
pytest tests/path/to/dir -v -s
```
### Running linters
```bash
# Run all pre-commit hooks on staged files:
pre-commit run
# Run on all files:
pre-commit run --all-files
# Run a specific hook:
pre-commit run ruff-check --all-files
# Run mypy as it is in CI:
pre-commit run mypy-3.10 --all-files --hook-stage manual
```
### Commit messages
Add attribution using commit trailers such as `Co-authored-by:` (other projects use `Assisted-by:` or `Generated-by:`). For example:
```text
Your commit message here
Co-authored-by: GitHub Copilot
Co-authored-by: Claude
Co-authored-by: gemini-code-assist
Signed-off-by: Your Name <your.email@example.com>
```
@AGENTS.md
......@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151;gfx928;gfx936;gfx938")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201;gfx928;gfx936;gfx938")
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
......@@ -293,6 +293,7 @@ set(VLLM_EXT_SRC
"csrc/fused_qknorm_rope_kernel.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/topk.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu"
......@@ -724,7 +725,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
# on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
......@@ -770,6 +771,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS)
set(SRCS
"csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu"
"csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1")
message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8
AND ES_MXFP8_GROUPED_MM_ARCHS)
message(STATUS "Not building ES MXFP8 grouped kernels as CUDA Compiler version is "
"not >= 12.8.")
else()
message(STATUS "Not building ES MXFP8 grouped kernels as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_FUSED_A_GEMM_ARCHS)
set(DSV3_FUSED_A_GEMM_SRC "csrc/dsv3_fused_a_gemm.cu")
set_gencode_flags_for_srcs(
SRCS "${DSV3_FUSED_A_GEMM_SRC}"
CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}")
list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC})
message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}")
else()
message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found "
"in CUDA target architectures.")
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
......@@ -952,7 +998,8 @@ set(VLLM_MOE_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu")
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/router_gemm.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
......@@ -1081,6 +1128,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
" in CUDA target architectures")
endif()
# DeepSeek V3 router GEMM kernel - requires SM90+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS)
set(DSV3_ROUTER_GEMM_SRC
"csrc/moe/dsv3_router_gemm_entry.cu"
"csrc/moe/dsv3_router_gemm_float_out.cu"
"csrc/moe/dsv3_router_gemm_bf16_out.cu")
set_gencode_flags_for_srcs(
SRCS "${DSV3_ROUTER_GEMM_SRC}"
CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}")
message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
else()
message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found"
" (requires SM90+ and CUDA >= 12.0)")
endif()
endif()
message(STATUS "Enabling moe extension.")
......
......@@ -187,7 +187,7 @@ python benchmark.py \
## Hardware Requirements
| Backend | Hardware |
|---------|----------|
| ------- | -------- |
| Flash/Triton/FlashInfer | Any CUDA GPU |
| CUTLASS MLA | Blackwell (SM100+) |
| FlashAttn MLA | Hopper (SM90+) |
......
......@@ -15,7 +15,6 @@ from .common import (
BenchmarkConfig,
BenchmarkResult,
MockLayer,
MockModelConfig,
ResultsFormatter,
get_attention_scale,
is_mla_backend,
......@@ -36,7 +35,6 @@ __all__ = [
"ResultsFormatter",
# Mock objects
"MockLayer",
"MockModelConfig",
# Utilities
"setup_mla_dims",
"get_attention_scale",
......
......@@ -229,3 +229,40 @@ def get_batch_stats(requests: list[BatchRequest]) -> dict:
sum(r.kv_len for r in requests) / len(requests) if requests else 0
),
}
def get_batch_type(batch_spec: str, spec_decode_threshold: int = 8) -> str:
"""
Classify a batch spec into a type string.
Args:
batch_spec: Batch specification string (e.g., "q2k", "8q1s1k", "2q2k_8q1s1k")
spec_decode_threshold: Max q_len to be considered spec-decode vs extend
Returns:
Type string: "prefill", "decode", "spec-decode", "extend", or "mixed (types...)"
"""
requests = parse_batch_spec(batch_spec)
# Classify each request
types_present = set()
for req in requests:
if req.is_decode:
types_present.add("decode")
elif req.is_prefill:
types_present.add("prefill")
elif req.is_extend:
# Distinguish spec-decode (small q_len) from extend (chunked prefill)
if req.q_len <= spec_decode_threshold:
types_present.add("spec-decode")
else:
types_present.add("extend")
if len(types_present) == 1:
return types_present.pop()
elif len(types_present) > 1:
# Sort for consistent output
sorted_types = sorted(types_present)
return f"mixed ({'+'.join(sorted_types)})"
else:
return "unknown"
......@@ -43,6 +43,7 @@ from common import (
ModelParameterSweep,
ParameterSweep,
ResultsFormatter,
batch_spec_sort_key,
is_mla_backend,
)
......@@ -58,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
"""Run MLA benchmark with appropriate backend."""
from mla_runner import run_mla_benchmark as run_mla
return run_mla(config.backend, config, **kwargs)
return run_mla(
config.backend, config, prefill_backend=config.prefill_backend, **kwargs
)
def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
......@@ -218,10 +221,13 @@ def run_model_parameter_sweep(
by_param_and_spec[key].append(r)
break
# Sort by param value then spec
# Sort by param value then spec (batch_size, q_len, kv_len)
sorted_keys = sorted(
by_param_and_spec.keys(),
key=lambda x: (int(x[0]) if x[0].isdigit() else x[0], x[1]),
key=lambda x: (
int(x[0]) if x[0].isdigit() else x[0],
batch_spec_sort_key(x[1]),
),
)
current_param_value = None
......@@ -330,7 +336,7 @@ def run_parameter_sweep(
by_spec[spec] = []
by_spec[spec].append(r)
for spec in sorted(by_spec.keys()):
for spec in sorted(by_spec.keys(), key=batch_spec_sort_key):
results = by_spec[spec]
best = min(results, key=lambda r: r.mean_time)
console.print(
......@@ -436,14 +442,21 @@ def main():
# Backend selection
parser.add_argument(
"--backends",
"--decode-backends",
nargs="+",
help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
"flashinfer_mla, flashattn_mla, flashmla)",
)
parser.add_argument(
"--backend",
help="Single backend (alternative to --backends)",
)
parser.add_argument(
"--prefill-backends",
nargs="+",
help="Prefill backends to compare (fa2, fa3, fa4). "
"Uses the first decode backend for impl construction.",
)
# Batch specifications
parser.add_argument(
......@@ -496,15 +509,24 @@ def main():
if "description" in yaml_config:
console.print(f"[dim]{yaml_config['description']}[/]")
# Override args with YAML values
# (YAML takes precedence unless CLI arg was explicitly set)
# Backend(s)
if "backend" in yaml_config:
args.backend = yaml_config["backend"]
args.backends = None
elif "backends" in yaml_config:
args.backends = yaml_config["backends"]
args.backend = None
# Override args with YAML values, but CLI args take precedence
# Check if CLI provided backends (they would be non-None and not default)
cli_backends_provided = args.backend is not None or args.backends is not None
# Backend(s) - only use YAML if CLI didn't specify
if not cli_backends_provided:
if "backend" in yaml_config:
args.backend = yaml_config["backend"]
args.backends = None
elif "backends" in yaml_config:
args.backends = yaml_config["backends"]
args.backend = None
elif "decode_backends" in yaml_config:
args.backends = yaml_config["decode_backends"]
args.backend = None
# Prefill backends (e.g., ["fa3", "fa4"])
args.prefill_backends = yaml_config.get("prefill_backends", None)
# Check for special modes
if "mode" in yaml_config:
......@@ -544,13 +566,15 @@ def main():
args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads)
args.block_size = model.get("block_size", args.block_size)
# Benchmark settings
if "benchmark" in yaml_config:
bench = yaml_config["benchmark"]
args.device = bench.get("device", args.device)
args.repeats = bench.get("repeats", args.repeats)
args.warmup_iters = bench.get("warmup_iters", args.warmup_iters)
args.profile_memory = bench.get("profile_memory", args.profile_memory)
# Benchmark settings (top-level keys)
if "device" in yaml_config:
args.device = yaml_config["device"]
if "repeats" in yaml_config:
args.repeats = yaml_config["repeats"]
if "warmup_iters" in yaml_config:
args.warmup_iters = yaml_config["warmup_iters"]
if "profile_memory" in yaml_config:
args.profile_memory = yaml_config["profile_memory"]
# Parameter sweep configuration
if "parameter_sweep" in yaml_config:
......@@ -604,7 +628,10 @@ def main():
# Determine backends
backends = args.backends or ([args.backend] if args.backend else ["flash"])
prefill_backends = getattr(args, "prefill_backends", None)
console.print(f"Backends: {', '.join(backends)}")
if prefill_backends:
console.print(f"Prefill backends: {', '.join(prefill_backends)}")
console.print(f"Batch specs: {', '.join(args.batch_specs)}")
console.print()
......@@ -841,37 +868,93 @@ def main():
else:
# Normal mode: compare backends
total = len(backends) * len(args.batch_specs)
decode_results = []
prefill_results = []
with tqdm(total=total, desc="Benchmarking") as pbar:
for spec in args.batch_specs:
for backend in backends:
config = BenchmarkConfig(
backend=backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
)
# Run decode backend comparison
if not prefill_backends:
# No prefill backends specified: compare decode backends as before
total = len(backends) * len(args.batch_specs)
result = run_benchmark(config)
all_results.append(result)
with tqdm(total=total, desc="Benchmarking") as pbar:
for spec in args.batch_specs:
for backend in backends:
config = BenchmarkConfig(
backend=backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
)
if not result.success:
console.print(f"[red]Error {backend} {spec}: {result.error}[/]")
result = run_benchmark(config)
decode_results.append(result)
pbar.update(1)
if not result.success:
console.print(
f"[red]Error {backend} {spec}: {result.error}[/]"
)
# Display results
console.print("\n[bold green]Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(all_results, backends)
pbar.update(1)
console.print("\n[bold green]Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(decode_results, backends)
# Run prefill backend comparison
if prefill_backends:
# Use first decode backend for impl construction
decode_backend = backends[0]
total = len(prefill_backends) * len(args.batch_specs)
console.print(
f"[yellow]Prefill comparison mode: "
f"using {decode_backend} for decode impl[/]"
)
with tqdm(total=total, desc="Prefill benchmarking") as pbar:
for spec in args.batch_specs:
for pb in prefill_backends:
config = BenchmarkConfig(
backend=decode_backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
prefill_backend=pb,
)
result = run_benchmark(config)
# Label result with prefill backend name for display
labeled_config = replace(result.config, backend=pb)
result = replace(result, config=labeled_config)
prefill_results.append(result)
if not result.success:
console.print(f"[red]Error {pb} {spec}: {result.error}[/]")
pbar.update(1)
console.print("\n[bold green]Prefill Backend Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(
prefill_results, prefill_backends, compare_to_fastest=True
)
all_results = decode_results + prefill_results
# Save results
if all_results:
......
......@@ -10,18 +10,37 @@ from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import numpy as np
import torch
from batch_spec import get_batch_type, parse_batch_spec
from rich.console import Console
from rich.table import Table
def batch_spec_sort_key(spec: str) -> tuple[int, int, int]:
"""
Extract sorting key from batch spec: (batch_size, max_q_len, max_kv_len).
This ensures results are sorted by batch size first, then query length,
then sequence length, rather than alphabetically.
"""
try:
requests = parse_batch_spec(spec)
batch_size = len(requests)
max_q_len = max(r.q_len for r in requests) if requests else 0
max_kv_len = max(r.kv_len for r in requests) if requests else 0
return (batch_size, max_q_len, max_kv_len)
except Exception:
# Fallback for unparsable specs
return (0, 0, 0)
# Mock classes for vLLM attention infrastructure
class MockHfConfig:
"""Mock HuggingFace config that satisfies vLLM's requirements."""
def __init__(self, mla_dims: dict):
def __init__(self, mla_dims: dict, index_topk: int | None = None):
self.num_attention_heads = mla_dims["num_q_heads"]
self.num_key_value_heads = mla_dims["num_kv_heads"]
self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"]
......@@ -32,6 +51,8 @@ class MockHfConfig:
self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"]
self.v_head_dim = mla_dims["v_head_dim"]
self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]
if index_topk is not None:
self.index_topk = index_topk
def get_text_config(self):
return self
......@@ -40,10 +61,7 @@ class MockHfConfig:
# Import AttentionLayerBase at module level to avoid circular dependencies
try:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
_HAS_ATTENTION_LAYER_BASE = True
except ImportError:
_HAS_ATTENTION_LAYER_BASE = False
AttentionLayerBase = object # Fallback
......@@ -59,6 +77,7 @@ class MockKVBProj:
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
self.out_dim = qk_nope_head_dim + v_head_dim
self.weight = torch.empty(0, dtype=torch.bfloat16)
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""
......@@ -82,6 +101,38 @@ class MockKVBProj:
return (result,) # Return as tuple to match ColumnParallelLinear API
class MockIndexer:
"""Mock Indexer for sparse MLA backends.
Provides topk_indices_buffer that sparse MLA backends use to determine
which KV cache slots to attend to for each token.
"""
def __init__(
self,
max_num_tokens: int,
topk_tokens: int,
device: torch.device,
):
self.topk_tokens = topk_tokens
self.topk_indices_buffer = torch.zeros(
(max_num_tokens, topk_tokens),
dtype=torch.int32,
device=device,
)
def fill_random_indices(self, num_tokens: int, max_kv_len: int):
"""Fill topk_indices_buffer with random valid indices for benchmarking."""
indices = torch.randint(
0,
max_kv_len,
(num_tokens, self.topk_tokens),
dtype=torch.int32,
device=self.topk_indices_buffer.device,
)
self.topk_indices_buffer[:num_tokens] = indices
class MockLayer(AttentionLayerBase):
"""Mock attention layer with scale parameters and impl.
......@@ -113,95 +164,6 @@ class MockLayer(AttentionLayerBase):
return self._kv_cache_spec
class MockModelConfig:
"""Mock model configuration."""
def __init__(
self,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16,
max_model_len: int = 32768,
):
self._n_q = num_q_heads
self._n_kv = num_kv_heads
self._d = head_dim
self.dtype = dtype
self.max_model_len = max_model_len
def get_num_attention_heads(self, _=None) -> int:
return self._n_q
def get_num_kv_heads(self, _=None) -> int:
return self._n_kv
def get_head_size(self) -> int:
return self._d
def get_num_layers(self) -> int:
"""Mock method for layer count queries."""
return 1
def get_sliding_window_for_layer(self, _layer_idx: int):
"""Mock method for sliding window queries."""
return None
def get_logits_soft_cap_for_layer(self, _layer_idx: int):
"""Mock method for logits soft cap queries."""
return None
def get_sm_scale_for_layer(self, _layer_idx: int) -> float:
"""Mock method for SM scale queries."""
return 1.0 / (self.get_head_size() ** 0.5)
class MockParallelConfig:
"""Mock parallel configuration."""
pass
class MockCompilationConfig:
"""Mock compilation configuration."""
def __init__(self):
self.full_cuda_graph = False
self.static_forward_context = {}
class MockVLLMConfig:
"""Mock VLLM configuration."""
def __init__(self):
self.compilation_config = MockCompilationConfig()
class MockRunner:
"""Mock GPU runner for metadata builders."""
def __init__(
self,
seq_lens: np.ndarray,
query_start_locs: np.ndarray,
device: torch.device,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
):
self.model_config = MockModelConfig(num_q_heads, num_kv_heads, head_dim, dtype)
self.parallel_config = MockParallelConfig()
self.vllm_config = MockVLLMConfig()
self.seq_lens_np = seq_lens
self.query_start_loc_np = query_start_locs
self.device = device
self.attention_chunk_size = None
self.num_query_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.dtype = dtype
@dataclass
class ParameterSweep:
"""Configuration for sweeping a backend parameter."""
......@@ -252,6 +214,7 @@ class BenchmarkConfig:
use_cuda_graphs: bool = False
# MLA-specific
prefill_backend: str | None = None
kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None
......@@ -316,14 +279,19 @@ class ResultsFormatter:
backends: List of backend names being compared
compare_to_fastest: Show percentage comparison to fastest
"""
# Group by batch spec
# Group by batch spec, preserving first-occurrence order
by_spec = {}
specs_order = []
for r in results:
spec = r.config.batch_spec
if spec not in by_spec:
by_spec[spec] = {}
specs_order.append(spec)
by_spec[spec][r.config.backend] = r
# Sort specs by (batch_size, q_len, kv_len) instead of alphabetically
specs_order = sorted(by_spec.keys(), key=batch_spec_sort_key)
# Create shortened backend names for display
def shorten_backend_name(name: str) -> str:
"""Shorten long backend names for table display."""
......@@ -337,6 +305,8 @@ class ResultsFormatter:
table = Table(title="Attention Benchmark Results")
table.add_column("Batch\nSpec", no_wrap=True)
table.add_column("Type", no_wrap=True)
table.add_column("Batch\nSize", justify="right", no_wrap=True)
multi = len(backends) > 1
for backend in backends:
......@@ -350,12 +320,14 @@ class ResultsFormatter:
table.add_column(col_rel, justify="right", no_wrap=False)
# Add rows
for spec in sorted(by_spec.keys()):
for spec in specs_order:
spec_results = by_spec[spec]
times = {b: r.mean_time for b, r in spec_results.items() if r.success}
best_time = min(times.values()) if times else 0.0
row = [spec]
batch_type = get_batch_type(spec)
batch_size = len(parse_batch_spec(spec))
row = [spec, batch_type, str(batch_size)]
for backend in backends:
if backend in spec_results:
r = spec_results[backend]
......@@ -486,10 +458,11 @@ def get_attention_scale(head_dim: int) -> float:
def is_mla_backend(backend: str) -> bool:
"""
Check if backend is an MLA backend using the backend's is_mla() property.
Check if backend is an MLA backend using the AttentionBackendEnum.
Args:
backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA")
backend: Backend name matching AttentionBackendEnum exactly
(e.g., "FLASHMLA_SPARSE")
Returns:
True if the backend is an MLA backend, False otherwise
......@@ -497,7 +470,8 @@ def is_mla_backend(backend: str) -> bool:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
try:
backend_class = AttentionBackendEnum[backend.upper()].get_class()
backend_enum = AttentionBackendEnum[backend]
backend_class = backend_enum.get_class()
return backend_class.is_mla()
except (KeyError, ValueError, ImportError):
except (KeyError, ValueError, ImportError, AttributeError):
return False
......@@ -3,7 +3,7 @@
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_q_heads: 128 # Base value, can be swept for TP simulation
num_kv_heads: 1 # MLA uses single latent KV
head_dim: 576
kv_lora_rank: 512
......@@ -12,6 +12,13 @@ model:
v_head_dim: 128
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep:
param_name: "num_q_heads"
values: [128, 64, 32, 16]
label_format: "{backend}_{value}h"
batch_specs:
# Small batches, varying sequence lengths
- "16q1s512" # 16 requests, 512 KV cache
......@@ -34,28 +41,30 @@ batch_specs:
# Very large batches
- "128q1s1k" # 128 requests, 1k KV cache
- "128q1s2k" # 128 requests, 2k KV cache
- "128q1s4k" # 128 requests, 4k KV cache
- "128q1s8k" # 128 requests, 8k KV cache
# Long context
- "32q1s16k" # 32 requests, 16k KV cache
- "32q1s32k" # 32 requests, 32k KV cache
backends:
- cutlass_mla
- flashinfer_mla
- flashattn_mla # Hopper only
- flashmla # Hopper only
- CUTLASS_MLA
- FLASHINFER_MLA
- FLASH_ATTN_MLA # Hopper only
- FLASHMLA # Hopper only
device: "cuda:0"
repeats: 5
warmup_iters: 3
repeats: 100
warmup_iters: 10
profile_memory: true
# Backend-specific tuning
cutlass_mla:
CUTLASS_MLA:
num_kv_splits: auto # or specific value like 4, 8, 16
flashattn_mla:
FLASH_ATTN_MLA:
reorder_batch_threshold: 512
flashmla:
FLASHMLA:
reorder_batch_threshold: 1
......@@ -45,10 +45,10 @@ batch_specs:
- "4q4k_60q1s4k" # 4 prefill + 60 decode
backends:
- cutlass_mla
- flashinfer_mla
- flashattn_mla # Hopper only
- flashmla # Hopper only
- CUTLASS_MLA
- FLASHINFER_MLA
- FLASH_ATTN_MLA # Hopper only
- FLASHMLA # Hopper only
device: "cuda:0"
repeats: 5
......
# MLA prefill backend comparison
#
# Compares all available MLA prefill backends:
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
#
# Uses cutlass_mla as the decode backend for impl construction
# (only the prefill path is exercised).
#
# Backends that aren't available on the current platform will report errors
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
#
# Usage:
# python benchmark.py --config configs/mla_prefill.yaml
description: "MLA prefill backend comparison"
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128
# model:
# name: "deepseek-v2-lite"
# num_layers: 27
# num_q_heads: 16
# num_kv_heads: 1
# head_dim: 576
# kv_lora_rank: 512
# qk_nope_head_dim: 128
# qk_rope_head_dim: 64
# v_head_dim: 128
# block_size: 128
batch_specs:
# Pure prefill
- "q512"
- "q1k"
- "q2k"
- "q4k"
- "q8k"
# Batched pure prefill
- "2q512"
- "2q1k"
- "2q2k"
- "2q4k"
- "2q8k"
- "4q512"
- "4q1k"
- "4q2k"
- "4q4k"
- "4q8k"
- "8q512"
- "8q1k"
- "8q2k"
- "8q4k"
- "8q8k"
# Chunked prefill / extend
# Short context
- "q128s1k"
- "q256s2k"
- "q512s4k"
- "q1ks4k"
- "q2ks8k"
- "2q128s1k"
- "2q256s2k"
- "2q512s4k"
- "2q1ks4k"
- "2q2ks8k"
- "4q128s1k"
- "4q256s2k"
- "4q512s4k"
- "4q1ks4k"
- "4q2ks8k"
- "8q128s1k"
- "8q256s2k"
- "8q512s4k"
- "8q1ks4k"
# Medium context
- "q128s16k"
- "q512s16k"
- "q1ks16k"
- "q2ks16k"
- "2q128s16k"
- "2q512s16k"
- "2q1ks16k"
- "2q2ks16k"
- "4q128s16k"
- "4q512s16k"
- "4q1ks16k"
- "4q2ks16k"
# Long context
- "q128s64k"
- "q512s64k"
- "q1ks64k"
- "q2ks64k"
- "2q128s64k"
- "2q512s64k"
- "2q1ks64k"
- "2q2ks64k"
decode_backends:
- CUTLASS_MLA
prefill_backends:
- fa2
- fa3
- fa4
- flashinfer
- cudnn
- trtllm
device: "cuda:0"
repeats: 20
warmup_iters: 5
# MLA prefill-only benchmark configuration for sparse backends
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep:
param_name: "num_q_heads"
values: [128, 64, 32, 16]
label_format: "{backend}_{value}h"
batch_specs:
# Pure prefill
- "1q512"
- "1q1k"
- "1q2k"
- "1q4k"
- "1q8k"
# Batched pure prefill
- "2q512"
- "2q1k"
- "2q2k"
- "2q4k"
- "2q8k"
- "4q512"
- "4q1k"
- "4q2k"
- "4q4k"
- "4q8k"
- "8q512"
- "8q1k"
- "8q2k"
- "8q4k"
- "8q8k"
# Extend
- "1q512s4k"
- "1q512s8k"
- "1q1ks8k"
- "1q2ks8k"
- "1q2ks16k"
- "1q4ks16k"
backends:
- FLASHMLA_SPARSE
- FLASHINFER_MLA_SPARSE
device: "cuda:0"
repeats: 10
warmup_iters: 3
profile_memory: true
......@@ -6,7 +6,7 @@
description: "Decode vs Prefill pipeline crossover analysis"
# Test FlashAttn MLA
backend: flashattn_mla
backend: FLASH_ATTN_MLA
# Mode: decode_vs_prefill comparison (special sweep mode)
# For each batch spec, we'll test both decode and prefill pipelines
......@@ -62,11 +62,10 @@ model:
block_size: 128
# Benchmark settings
benchmark:
device: "cuda:0"
repeats: 15 # More repeats for spec decode variance
warmup_iters: 5
profile_memory: false
device: "cuda:0"
repeats: 15 # More repeats for spec decode variance
warmup_iters: 5
profile_memory: false
# Output
output:
......
......@@ -41,18 +41,17 @@ batch_specs:
# Backends that support query length > 1
backends:
- flashattn_mla # reorder_batch_threshold = 512
- flashmla # reorder_batch_threshold = 1 (tunable)
- FLASH_ATTN_MLA # reorder_batch_threshold = 512
- FLASHMLA # reorder_batch_threshold = 1 (tunable)
# FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism
# - flashinfer_mla
# - FLASHINFER_MLA
# Benchmark settings
benchmark:
device: "cuda:0"
repeats: 10 # More repeats for statistical significance
warmup_iters: 5
profile_memory: false
device: "cuda:0"
repeats: 10 # More repeats for statistical significance
warmup_iters: 5
profile_memory: false
# Test these threshold values for optimization
parameter_sweep:
......
......@@ -25,14 +25,22 @@ batch_specs:
- "4q1k_16q1s2k" # 4 prefill + 16 decode
- "2q4k_32q1s1k" # 2 large prefill + 32 decode
# Context extension
- "q1ks2k" # 1k query, 2k sequence (chunked prefill)
# Speculative decode (q <= 8)
- "16q2s1k" # 16 requests, 2 spec tokens, 1k KV cache
- "16q4s1k" # 16 requests, 4 spec tokens, 1k KV cache
- "16q8s1k" # 16 requests, 8 spec tokens, 1k KV cache
- "32q4s2k" # 32 requests, 4 spec tokens, 2k KV cache
- "8q8s4k" # 8 requests, 8 spec tokens, 4k KV cache
# Context extension (chunked prefill)
- "q1ks2k" # 1k query, 2k sequence
- "2q1ks4k" # 2 requests: 1k query, 4k sequence
# Available backends: FLASH_ATTN, TRITON_ATTN, FLASHINFER
backends:
- flash
- triton
- flashinfer
- FLASH_ATTN
- TRITON_ATTN
- FLASHINFER
device: "cuda:0"
repeats: 5
......
......@@ -8,14 +8,13 @@ This module provides helpers for running MLA backends without
needing full VllmConfig integration.
"""
import importlib
import numpy as np
import torch
from batch_spec import parse_batch_spec
from common import (
BenchmarkResult,
MockHfConfig,
MockIndexer,
MockKVBProj,
MockLayer,
setup_mla_dims,
......@@ -62,6 +61,8 @@ def create_minimal_vllm_config(
block_size: int = 128,
max_num_seqs: int = 256,
mla_dims: dict | None = None,
index_topk: int | None = None,
prefill_backend: str | None = None,
) -> VllmConfig:
"""
Create minimal VllmConfig for MLA benchmarks.
......@@ -73,6 +74,11 @@ def create_minimal_vllm_config(
max_num_seqs: Maximum number of sequences
mla_dims: Optional custom MLA dimensions dict. If not provided, uses
setup_mla_dims(model_name)
index_topk: Optional topk value for sparse MLA backends. If provided,
the config will include index_topk for sparse attention.
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
"cudnn", "trtllm"). Configures the attention config to
force the specified prefill backend.
Returns:
VllmConfig for benchmarking
......@@ -82,7 +88,7 @@ def create_minimal_vllm_config(
mla_dims = setup_mla_dims(model_name)
# Create mock HF config first (avoids downloading from HuggingFace)
mock_hf_config = MockHfConfig(mla_dims)
mock_hf_config = MockHfConfig(mla_dims, index_topk=index_topk)
# Create a temporary minimal config.json to avoid HF downloads
# This ensures consistent ModelConfig construction without network access
......@@ -120,16 +126,12 @@ def create_minimal_vllm_config(
seed=0,
max_model_len=32768,
quantization=None,
quantization_param_path=None,
enforce_eager=False,
max_context_len_to_capture=None,
max_seq_len_to_capture=8192,
max_logprobs=20,
disable_sliding_window=False,
skip_tokenizer_init=True,
served_model_name=None,
limit_mm_per_prompt=None,
use_async_output_proc=True,
config_format="auto",
)
finally:
......@@ -147,7 +149,6 @@ def create_minimal_vllm_config(
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=False,
)
......@@ -166,7 +167,7 @@ def create_minimal_vllm_config(
compilation_config = CompilationConfig()
return VllmConfig(
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
......@@ -174,62 +175,147 @@ def create_minimal_vllm_config(
compilation_config=compilation_config,
)
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.flash_attn_version = prefill_cfg[
"flash_attn_version"
]
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
"disable_flashinfer_prefill"
]
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
"use_cudnn_prefill"
]
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[
"use_trtllm_ragged_deepseek_prefill"
]
return vllm_config
# ============================================================================
# Backend Configuration
# Prefill Backend Configuration
# ============================================================================
# Backend name to class name prefix mapping
_BACKEND_NAME_MAP = {
"flashattn_mla": "FlashAttnMLA",
"flashmla": "FlashMLA",
"flashinfer_mla": "FlashInferMLA",
"cutlass_mla": "CutlassMLA",
# Maps prefill backend names to attention config overrides.
# FA backends set flash_attn_version and disable non-FA paths.
# Non-FA backends enable their specific path and disable others.
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
"fa2": {
"flash_attn_version": 2,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"fa3": {
"flash_attn_version": 3,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"fa4": {
"flash_attn_version": 4,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"flashinfer": {
"flash_attn_version": None,
"disable_flashinfer_prefill": False,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"cudnn": {
"flash_attn_version": None,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": True,
"use_trtllm_ragged_deepseek_prefill": False,
},
"trtllm": {
"flash_attn_version": None,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": True,
},
}
# Special properties that differ from defaults
def get_prefill_backend_config(prefill_backend: str) -> dict:
"""Get attention config overrides for a prefill backend."""
if prefill_backend not in _PREFILL_BACKEND_CONFIG:
raise ValueError(
f"Unknown prefill backend: {prefill_backend!r}. "
f"Available: {list(_PREFILL_BACKEND_CONFIG.keys())}"
)
return _PREFILL_BACKEND_CONFIG[prefill_backend]
# ============================================================================
# Decode Backend Configuration
# ============================================================================
# Backend-specific properties that can't be inferred from the backend class
# Keys are AttentionBackendEnum names (uppercase)
_BACKEND_PROPERTIES = {
"flashmla": {
"FLASHMLA": {
"query_format": "concat", # Single concatenated tensor (vs tuple)
"block_size": 64, # FlashMLA uses fixed block size
},
"flashinfer_mla": {
"block_size": 64, # FlashInfer MLA only supports 32 or 64
"FLASHMLA_SPARSE": {
"query_format": "concat", # Single concatenated tensor (vs tuple)
},
}
def _get_backend_config(backend: str) -> dict:
"""
Get backend configuration using naming conventions.
All MLA backends follow the pattern:
- Module: vllm.v1.attention.backends.mla.{backend}
- Impl: {Name}Impl
- Metadata: {Name}Metadata (or MLACommonMetadata)
- DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
- MetadataBuilder: {Name}MetadataBuilder
Get backend configuration from AttentionBackendEnum.
Uses the registry to get the backend class and extract configuration
from its methods (get_impl_cls, get_builder_cls, is_sparse, etc.).
Args:
backend: Backend name matching AttentionBackendEnum exactly
(e.g., "FLASHMLA_SPARSE")
Returns:
Dict with backend configuration
"""
if backend not in _BACKEND_NAME_MAP:
raise ValueError(f"Unknown backend: {backend}")
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.registry import AttentionBackendEnum
name = _BACKEND_NAME_MAP[backend]
try:
backend_enum = AttentionBackendEnum[backend]
backend_class = backend_enum.get_class()
except (KeyError, ValueError) as e:
valid_backends = [e.name for e in AttentionBackendEnum if e.name != "CUSTOM"]
raise ValueError(
f"Unknown backend: {backend}. "
f"Valid MLA backends: {[b for b in valid_backends if 'MLA' in b]}"
) from e
# Get block size from backend class
block_sizes = backend_class.get_supported_kernel_block_sizes()
# Use first supported block size (backends typically support one for MLA)
block_size = block_sizes[0] if block_sizes else None
if isinstance(block_size, MultipleOf):
# No fixed block size; fall back to config value
block_size = None
# Check if sparse via class method if available
is_sparse = getattr(backend_class, "is_sparse", lambda: False)()
# Get properties that can't be inferred
props = _BACKEND_PROPERTIES.get(backend, {})
# Check if backend uses common metadata (FlashInfer, CUTLASS)
uses_common = backend in ("flashinfer_mla", "cutlass_mla")
return {
"module": f"vllm.v1.attention.backends.mla.{backend}",
"impl_class": f"{name}Impl",
"metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
"decode_metadata_class": "MLACommonDecodeMetadata"
if uses_common
else f"{name}DecodeMetadata",
"builder_class": f"{name}MetadataBuilder",
"backend_class": backend_class,
"impl_class": backend_class.get_impl_cls(),
"builder_class": backend_class.get_builder_cls(),
"query_format": props.get("query_format", "tuple"),
"block_size": props.get("block_size", None),
"block_size": block_size,
"is_sparse": is_sparse,
}
......@@ -447,22 +533,26 @@ def _create_backend_impl(
mla_dims: dict,
vllm_config: VllmConfig,
device: torch.device,
max_num_tokens: int = 8192,
index_topk: int | None = None,
):
"""
Create backend implementation instance.
Args:
backend_cfg: Backend configuration dict
backend_cfg: Backend configuration dict from _get_backend_config()
mla_dims: MLA dimension configuration
vllm_config: VllmConfig instance
device: Target device
max_num_tokens: Maximum number of tokens for sparse indexer buffer
index_topk: Topk value for sparse MLA backends
Returns:
Tuple of (impl, layer, builder_instance)
Tuple of (impl, layer, builder_instance, indexer)
"""
# Import backend classes
backend_module = importlib.import_module(backend_cfg["module"])
impl_class = getattr(backend_module, backend_cfg["impl_class"])
# Get classes from backend config (already resolved by _get_backend_config)
impl_class = backend_cfg["impl_class"]
builder_class = backend_cfg["builder_class"]
# Calculate scale
scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
......@@ -474,26 +564,44 @@ def _create_backend_impl(
v_head_dim=mla_dims["v_head_dim"],
)
# Create indexer for sparse backends
indexer = None
if backend_cfg.get("is_sparse", False):
if index_topk is None:
index_topk = 2048 # Default topk for sparse MLA
indexer = MockIndexer(
max_num_tokens=max_num_tokens,
topk_tokens=index_topk,
device=device,
)
# Build impl kwargs
impl_kwargs = {
"num_heads": mla_dims["num_q_heads"],
"head_size": mla_dims["head_dim"],
"scale": scale,
"num_kv_heads": mla_dims["num_kv_heads"],
"alibi_slopes": None,
"sliding_window": None,
"kv_cache_dtype": "auto",
"logits_soft_cap": None,
"attn_type": "decoder",
"kv_sharing_target_layer_name": None,
"q_lora_rank": None,
"kv_lora_rank": mla_dims["kv_lora_rank"],
"qk_nope_head_dim": mla_dims["qk_nope_head_dim"],
"qk_rope_head_dim": mla_dims["qk_rope_head_dim"],
"qk_head_dim": mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
"v_head_dim": mla_dims["v_head_dim"],
"kv_b_proj": mock_kv_b_proj,
}
# Add indexer for sparse backends
if indexer is not None:
impl_kwargs["indexer"] = indexer
# Create impl
impl = impl_class(
num_heads=mla_dims["num_q_heads"],
head_size=mla_dims["head_dim"],
scale=scale,
num_kv_heads=mla_dims["num_kv_heads"],
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=mla_dims["kv_lora_rank"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
kv_b_proj=mock_kv_b_proj,
)
impl = impl_class(**impl_kwargs)
# Initialize DCP attributes
if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
......@@ -515,9 +623,7 @@ def _create_backend_impl(
# Create builder instance if needed
builder_instance = None
if backend_cfg["builder_class"]:
builder_class = getattr(backend_module, backend_cfg["builder_class"])
if builder_class:
# Populate static_forward_context so builder can find the layer
# MockLayer inherits from AttentionLayerBase, so isinstance checks pass
vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
......@@ -529,7 +635,7 @@ def _create_backend_impl(
device=device,
)
return impl, layer, builder_instance
return impl, layer, builder_instance, indexer
# ============================================================================
......@@ -594,6 +700,7 @@ def _run_single_benchmark(
backend_cfg: dict,
mla_dims: dict,
device: torch.device,
indexer=None,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
......@@ -606,6 +713,7 @@ def _run_single_benchmark(
backend_cfg: Backend configuration dict
mla_dims: MLA dimension configuration
device: Target device
indexer: Optional MockIndexer for sparse backends
Returns:
BenchmarkResult with timing statistics
......@@ -613,7 +721,9 @@ def _run_single_benchmark(
# Parse batch spec
requests = parse_batch_spec(config.batch_spec)
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv_len = max(kv_lens)
# Determine block size
block_size = backend_cfg["block_size"] or config.block_size
......@@ -641,13 +751,16 @@ def _run_single_benchmark(
torch.bfloat16,
)
# Fill indexer with random indices for sparse backends
is_sparse = backend_cfg.get("is_sparse", False)
if is_sparse and indexer is not None:
indexer.fill_random_indices(total_q, max_kv_len)
# Determine which forward method to use based on metadata
if metadata.decode is not None:
forward_fn = lambda: impl._forward_decode(
decode_inputs, kv_cache, metadata, layer
)
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
elif metadata.prefill is not None:
forward_fn = lambda: impl._forward_prefill(
forward_fn = lambda: impl.forward_mha(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
......@@ -662,7 +775,7 @@ def _run_single_benchmark(
# Warmup
for _ in range(config.warmup_iters):
forward_fn()
torch.cuda.synchronize()
torch.accelerator.synchronize()
# Benchmark
times = []
......@@ -675,7 +788,7 @@ def _run_single_benchmark(
forward_fn()
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers)
......@@ -693,20 +806,26 @@ def _run_single_benchmark(
def _run_mla_benchmark_batched(
backend: str,
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
index_topk: int = 2048,
prefill_backend: str | None = None,
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
flashinfer_mla_sparse, flashmla_sparse
This function reuses backend initialization across multiple benchmarks
to avoid setup/teardown overhead.
Args:
backend: Backend name
backend: Backend name (decode backend used for impl construction)
configs_with_params: List of (config, threshold, num_splits) tuples
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
- num_splits: num_kv_splits (CUTLASS only)
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
Returns:
List of BenchmarkResult objects
......@@ -716,7 +835,7 @@ def _run_mla_benchmark_batched(
backend_cfg = _get_backend_config(backend)
device = torch.device(configs_with_params[0][0].device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
# Determine block size
config_block_size = configs_with_params[0][0].block_size
......@@ -730,21 +849,75 @@ def _run_mla_benchmark_batched(
if mla_dims is None:
mla_dims = setup_mla_dims("deepseek-v3")
# Determine if this is a sparse backend
is_sparse = backend_cfg.get("is_sparse", False)
# Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config = create_minimal_vllm_config(
model_name="deepseek-v3", # Used only for model path
block_size=block_size,
mla_dims=mla_dims, # Use custom dims from config or default
index_topk=index_topk if is_sparse else None,
prefill_backend=prefill_backend,
)
results = []
with set_current_vllm_config(vllm_config):
# Create backend impl, layer, and builder (reused across benchmarks)
impl, layer, builder_instance = _create_backend_impl(
backend_cfg, mla_dims, vllm_config, device
# Clear cached prefill backend detection functions so they re-evaluate
# with the current VllmConfig. These are @functools.cache decorated and
# would otherwise return stale results from a previous backend's config.
from vllm.model_executor.layers.attention.mla_attention import (
use_cudnn_prefill,
use_flashinfer_prefill,
use_trtllm_ragged_deepseek_prefill,
)
use_flashinfer_prefill.cache_clear()
use_cudnn_prefill.cache_clear()
use_trtllm_ragged_deepseek_prefill.cache_clear()
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
impl, layer, builder_instance, indexer = _create_backend_impl(
backend_cfg,
mla_dims,
vllm_config,
device,
index_topk=index_topk if is_sparse else None,
)
# Verify the actual prefill backend matches what was requested
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
fa_version = prefill_cfg["flash_attn_version"]
if fa_version is not None:
# FA backend: verify the impl's FA version
actual_fa_version = getattr(impl, "vllm_flash_attn_version", None)
if actual_fa_version != fa_version:
raise RuntimeError(
f"Prefill backend '{prefill_backend}' requested FA "
f"version {fa_version}, but the impl is using FA "
f"version {actual_fa_version}. Check "
f"vllm/v1/attention/backends/fa_utils.py."
)
else:
# Non-FA backend: verify the builder picked the right path
expected_flags = {
"flashinfer": "_use_fi_prefill",
"cudnn": "_use_cudnn_prefill",
"trtllm": "_use_trtllm_ragged_prefill",
}
flag_name = expected_flags.get(prefill_backend)
if flag_name and not getattr(builder_instance, flag_name, False):
raise RuntimeError(
f"Prefill backend '{prefill_backend}' was requested "
f"but the metadata builder did not enable it. This "
f"usually means a dependency is missing (e.g., "
f"flashinfer not installed) or the platform doesn't "
f"support it."
)
# Run each benchmark with the shared impl
for config, threshold, num_splits in configs_with_params:
# Set threshold for this benchmark (FlashAttn/FlashMLA only)
......@@ -768,6 +941,7 @@ def _run_mla_benchmark_batched(
backend_cfg,
mla_dims,
device,
indexer=indexer,
)
results.append(result)
......@@ -793,20 +967,27 @@ def run_mla_benchmark(
config,
reorder_batch_threshold: int | None = None,
num_kv_splits: int | None = None,
index_topk: int = 2048,
prefill_backend: str | None = None,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
flashinfer_mla_sparse, flashmla_sparse
Always uses batched execution internally for optimal performance.
Args:
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
flashinfer_mla_sparse, flashmla_sparse)
config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
(single config mode only)
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
......@@ -816,9 +997,9 @@ def run_mla_benchmark(
# Already in batched format
if len(config) > 0 and isinstance(config[0], tuple):
# Format: [(cfg, param), ...] where param is threshold or num_splits
if backend in ("flashattn_mla", "flashmla"):
if backend in ("flashattn_mla", "flashmla", "flashmla_sparse"):
configs_with_params = [(cfg, param, None) for cfg, param in config]
else: # cutlass_mla or flashinfer_mla
else: # cutlass_mla, flashinfer_mla, or sparse backends
configs_with_params = [(cfg, None, param) for cfg, param in config]
else:
# Format: [cfg, ...] - just configs
......@@ -830,7 +1011,9 @@ def run_mla_benchmark(
return_single = True
# Use unified batched execution
results = _run_mla_benchmark_batched(backend, configs_with_params)
results = _run_mla_benchmark_batched(
backend, configs_with_params, index_topk, prefill_backend=prefill_backend
)
# Return single result or list based on input
return results[0] if return_single else results
......@@ -8,7 +8,9 @@ This module provides helpers for running standard attention backends
(FlashAttention, Triton, FlashInfer) with real vLLM integration.
"""
import logging
import types
from contextlib import contextmanager
import numpy as np
import torch
......@@ -24,8 +26,13 @@ from vllm.config import (
ParallelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
get_kv_cache_layout,
set_kv_cache_layout,
)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================
......@@ -33,37 +40,41 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
# ============================================================================
_BACKEND_CONFIG = {
"flash": {
"module": "vllm.v1.attention.backends.flash_attn",
"backend_class": "FlashAttentionBackend",
"dtype": torch.float16,
"cache_layout": "standard",
# ^ [2, num_blocks, block_size, num_kv_heads, head_dim]
},
"triton": {
"module": "vllm.v1.attention.backends.triton_attn",
"backend_class": "TritonAttentionBackend",
"dtype": torch.float32,
"cache_layout": "standard",
},
"flashinfer": {
"module": "vllm.v1.attention.backends.flashinfer",
"backend_class": "FlashInferBackend",
"dtype": torch.float16,
"cache_layout": "flashinfer",
# ^ [num_blocks, 2, block_size, num_kv_heads, head_dim]
},
}
def _get_backend_config(backend: str) -> dict:
"""
Get backend configuration from AttentionBackendEnum.
Args:
backend: Backend name matching AttentionBackendEnum exactly
(e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER")
Returns:
Dict with backend_class
"""
from vllm.v1.attention.backends.registry import AttentionBackendEnum
def _get_backend_config(backend: str) -> dict:
if backend not in _BACKEND_CONFIG:
try:
backend_enum = AttentionBackendEnum[backend]
backend_class = backend_enum.get_class()
except (KeyError, ValueError) as e:
valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"]
raise ValueError(
f"Unknown backend: {backend}. "
f"Available: {', '.join(_BACKEND_CONFIG.keys())}"
)
return _BACKEND_CONFIG[backend]
f"Unknown backend: {backend}. Valid backends: {valid_backends}"
) from e
return {"backend_class": backend_class}
@contextmanager
def log_warnings_and_errors_only():
"""Temporarily set vLLM logger to WARNING level."""
logger = logging.getLogger("vllm")
old_level = logger.level
logger.setLevel(logging.WARNING)
try:
yield
finally:
logger.setLevel(old_level)
# ============================================================================
......@@ -88,11 +99,7 @@ def _build_common_attn_metadata(
query_start_loc_cpu = query_start_loc.cpu()
seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
seq_lens_cpu = seq_lens.cpu()
max_seq_len = int(seq_lens_cpu.max())
context_lens = [kv - q for kv, q in zip(kv_lens, q_lens)]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
max_seq_len = int(seq_lens.max().item())
max_blocks = (max(kv_lens) + block_size - 1) // block_size
num_blocks = batch_size * max_blocks
......@@ -107,8 +114,6 @@ def _build_common_attn_metadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_size,
num_actual_tokens=total_tokens,
max_query_len=max_query_len,
......@@ -121,7 +126,6 @@ def _build_common_attn_metadata(
def _create_vllm_config(
config: BenchmarkConfig,
dtype: torch.dtype,
max_num_blocks: int,
) -> VllmConfig:
"""Create a VllmConfig for benchmarking with mock model methods."""
......@@ -129,7 +133,7 @@ def _create_vllm_config(
model="meta-llama/Meta-Llama-3-8B",
tokenizer="meta-llama/Meta-Llama-3-8B",
trust_remote_code=False,
dtype=dtype,
dtype="auto", # Use model's native dtype
seed=0,
max_model_len=1024,
)
......@@ -137,7 +141,6 @@ def _create_vllm_config(
cache_config = CacheConfig(
block_size=config.block_size,
cache_dtype="auto",
swap_space=0,
)
cache_config.num_gpu_blocks = max_num_blocks
cache_config.num_cpu_blocks = 0
......@@ -198,15 +201,12 @@ def _create_backend_impl(
backend_cfg: dict,
config: BenchmarkConfig,
device: torch.device,
dtype: torch.dtype,
):
"""Create backend implementation instance."""
import importlib
backend_module = importlib.import_module(backend_cfg["module"])
backend_class = getattr(backend_module, backend_cfg["backend_class"])
backend_class = backend_cfg["backend_class"]
scale = get_attention_scale(config.head_dim)
dtype = backend_cfg["dtype"]
impl = backend_class.get_impl_cls()(
num_heads=config.num_q_heads,
......@@ -227,7 +227,7 @@ def _create_backend_impl(
layer = MockLayer(device, kv_cache_spec=kv_cache_spec)
return backend_class, impl, layer, dtype
return backend_class, impl, layer
def _create_metadata_builder(
......@@ -235,11 +235,44 @@ def _create_metadata_builder(
kv_cache_spec: FullAttentionSpec,
vllm_config: VllmConfig,
device: torch.device,
backend_name: str = "",
):
"""Create metadata builder instance."""
return backend_class.get_builder_cls()(
layer_names = ["layer_0"]
builder_cls = backend_class.get_builder_cls()
# Flashinfer needs get_per_layer_parameters mocked since we don't have
# real model layers registered
if backend_name == "FLASHINFER":
import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
head_size = vllm_config.model_config.get_head_size()
return {
layer_name: PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5), # Standard scale
)
for layer_name in layer_names
}
with unittest.mock.patch(
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
mock_get_per_layer_parameters,
):
return builder_cls(
kv_cache_spec=kv_cache_spec,
layer_names=layer_names,
vllm_config=vllm_config,
device=device,
)
return builder_cls(
kv_cache_spec=kv_cache_spec,
layer_names=["layer_0"],
layer_names=layer_names,
vllm_config=vllm_config,
device=device,
)
......@@ -281,39 +314,44 @@ def _create_input_tensors(
def _create_kv_cache(
config: BenchmarkConfig,
max_num_blocks: int,
cache_layout: str,
backend_class,
device: torch.device,
dtype: torch.dtype,
) -> list:
"""Create KV cache tensors for all layers."""
if cache_layout == "flashinfer":
# FlashInfer layout: [num_blocks, 2, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
max_num_blocks,
2,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
else:
# Standard layout: [2, num_blocks, block_size, num_kv_heads, head_dim]
cache_list = [
torch.zeros(
2,
max_num_blocks,
config.block_size,
config.num_kv_heads,
config.head_dim,
device=device,
dtype=dtype,
)
for _ in range(config.num_layers)
]
"""Create KV cache tensors for all layers using the backend's methods.
Uses the backend's get_kv_cache_shape() and get_kv_cache_stride_order()
to create the cache with the correct shape and memory layout.
"""
# Get the logical shape from the backend
cache_shape = backend_class.get_kv_cache_shape(
num_blocks=max_num_blocks,
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
)
# Get the stride order for custom memory layout
try:
stride_order = backend_class.get_kv_cache_stride_order()
assert len(stride_order) == len(cache_shape)
except (AttributeError, NotImplementedError):
stride_order = tuple(range(len(cache_shape)))
# Permute shape to physical layout order
physical_shape = tuple(cache_shape[i] for i in stride_order)
# Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
cache_list = []
for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=dtype)
# Permute to logical view
cache = cache.permute(*inv_order)
cache_list.append(cache)
return cache_list
......@@ -352,7 +390,7 @@ def _run_single_benchmark(
attn_metadata,
output=out,
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
# Benchmark
times = []
......@@ -373,15 +411,15 @@ def _run_single_benchmark(
)
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
elapsed_ms = start.elapsed_time(end)
times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer
mem_stats = {}
if config.profile_memory:
mem_stats = {
"allocated_mb": torch.cuda.memory_allocated(device) / 1024**2,
"reserved_mb": torch.cuda.memory_reserved(device) / 1024**2,
"allocated_mb": torch.accelerator.memory_allocated(device) / 1024**2,
"reserved_mb": torch.accelerator.memory_reserved(device) / 1024**2,
}
return times, mem_stats
......@@ -396,7 +434,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""
Run standard attention benchmark with real kernels.
Supports: flash, triton, flashinfer
Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER
Args:
config: Benchmark configuration
......@@ -405,66 +443,85 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
BenchmarkResult with timing and memory statistics
"""
device = torch.device(config.device)
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
backend_cfg = _get_backend_config(config.backend)
requests = parse_batch_spec(config.batch_spec)
if config.backend == "flashinfer":
if config.backend == "FLASHINFER":
requests = reorder_for_flashinfer(requests)
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv = max(kv_lens)
batch_size = len(q_lens)
max_num_blocks = (max_kv + config.block_size - 1) // config.block_size
backend_class, impl, layer, dtype = _create_backend_impl(
backend_cfg, config, device
)
# Calculate total blocks needed: batch_size * max_blocks_per_request
max_blocks_per_request = (max_kv + config.block_size - 1) // config.block_size
max_num_blocks = batch_size * max_blocks_per_request
# Suppress vLLM logs during setup to reduce spam
with log_warnings_and_errors_only():
# Create vllm_config first - uses model's native dtype via "auto"
vllm_config = _create_vllm_config(config, max_num_blocks)
dtype = vllm_config.model_config.dtype
# Wrap everything in set_current_vllm_config context
# This is required for backends like flashinfer that need global config
with set_current_vllm_config(vllm_config):
backend_class, impl, layer = _create_backend_impl(
backend_cfg, config, device, dtype
)
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
# Set KV cache layout if the backend requires a specific one
# (e.g., FlashInfer requires HND on SM100/Blackwell for TRTLLM attention)
required_layout = backend_class.get_required_kv_cache_layout()
if required_layout is not None:
set_kv_cache_layout(required_layout)
get_kv_cache_layout.cache_clear()
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
common_metadata = _build_common_attn_metadata(
q_lens, kv_lens, config.block_size, device
)
vllm_config = _create_vllm_config(config, dtype, max_num_blocks)
kv_cache_spec = FullAttentionSpec(
block_size=config.block_size,
num_kv_heads=config.num_kv_heads,
head_size=config.head_dim,
dtype=dtype,
)
builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device
)
builder = _create_metadata_builder(
backend_class, kv_cache_spec, vllm_config, device, config.backend
)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_metadata,
)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_metadata,
)
q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype)
q_list, k_list, v_list = _create_input_tensors(
config, total_q, device, dtype
)
cache_list = _create_kv_cache(
config, max_num_blocks, backend_cfg["cache_layout"], device, dtype
)
cache_list = _create_kv_cache(
config, max_num_blocks, backend_class, device, dtype
)
times, mem_stats = _run_single_benchmark(
config,
impl,
layer,
q_list,
k_list,
v_list,
cache_list,
attn_metadata,
device,
dtype,
)
times, mem_stats = _run_single_benchmark(
config,
impl,
layer,
q_list,
k_list,
v_list,
cache_list,
attn_metadata,
device,
dtype,
)
mean_time = np.mean(times)
throughput = total_q / mean_time if mean_time > 0 else 0
......
......@@ -41,7 +41,7 @@ MODEL=meta-llama/Llama-3.3-70B-Instruct SYSTEM=TPU TP=8 DOWNLOAD_DIR='' INPUT_LE
| --- | --- | --- |
| `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` |
| `MODEL` | **Required.** The Hugging Face model identifier to be served by vllm. | `"meta-llama/Llama-3.1-8B-Instruct"` |
| `SYSTEM`| **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (For other systems, it might not support saving profiles) | `"TPU"` |
| `SYSTEM` | **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (For other systems, it might not support saving profiles) | `"TPU"` |
| `TP` | **Required.** The tensor-parallelism size. | `1` |
| `DOWNLOAD_DIR` | **Required.** Directory to download and load model weights from. | `""` (default download path) |
| `INPUT_LEN` | **Required.** Request input length. | `4000` |
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment