Unverified Commit f444c05c authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Use FA4 for MLA prefill (#34732)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 85199f96
...@@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: ...@@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
"""Run MLA benchmark with appropriate backend.""" """Run MLA benchmark with appropriate backend."""
from mla_runner import run_mla_benchmark as run_mla from mla_runner import run_mla_benchmark as run_mla
return run_mla(config.backend, config, **kwargs) return run_mla(
config.backend, config, prefill_backend=config.prefill_backend, **kwargs
)
def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
...@@ -440,14 +442,21 @@ def main(): ...@@ -440,14 +442,21 @@ def main():
# Backend selection # Backend selection
parser.add_argument( parser.add_argument(
"--backends", "--backends",
"--decode-backends",
nargs="+", nargs="+",
help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, " help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
"flashinfer_mla, flashattn_mla, flashmla)", "flashinfer_mla, flashattn_mla, flashmla)",
) )
parser.add_argument( parser.add_argument(
"--backend", "--backend",
help="Single backend (alternative to --backends)", help="Single backend (alternative to --backends)",
) )
parser.add_argument(
"--prefill-backends",
nargs="+",
help="Prefill backends to compare (fa2, fa3, fa4). "
"Uses the first decode backend for impl construction.",
)
# Batch specifications # Batch specifications
parser.add_argument( parser.add_argument(
...@@ -502,7 +511,7 @@ def main(): ...@@ -502,7 +511,7 @@ def main():
# Override args with YAML values, but CLI args take precedence # Override args with YAML values, but CLI args take precedence
# Check if CLI provided backends (they would be non-None and not default) # Check if CLI provided backends (they would be non-None and not default)
cli_backends_provided = args.backends is not None or args.backend is not None cli_backends_provided = args.backend is not None or args.backends is not None
# Backend(s) - only use YAML if CLI didn't specify # Backend(s) - only use YAML if CLI didn't specify
if not cli_backends_provided: if not cli_backends_provided:
...@@ -512,6 +521,12 @@ def main(): ...@@ -512,6 +521,12 @@ def main():
elif "backends" in yaml_config: elif "backends" in yaml_config:
args.backends = yaml_config["backends"] args.backends = yaml_config["backends"]
args.backend = None args.backend = None
elif "decode_backends" in yaml_config:
args.backends = yaml_config["decode_backends"]
args.backend = None
# Prefill backends (e.g., ["fa3", "fa4"])
args.prefill_backends = yaml_config.get("prefill_backends", None)
# Check for special modes # Check for special modes
if "mode" in yaml_config: if "mode" in yaml_config:
...@@ -613,7 +628,10 @@ def main(): ...@@ -613,7 +628,10 @@ def main():
# Determine backends # Determine backends
backends = args.backends or ([args.backend] if args.backend else ["flash"]) backends = args.backends or ([args.backend] if args.backend else ["flash"])
prefill_backends = getattr(args, "prefill_backends", None)
console.print(f"Backends: {', '.join(backends)}") console.print(f"Backends: {', '.join(backends)}")
if prefill_backends:
console.print(f"Prefill backends: {', '.join(prefill_backends)}")
console.print(f"Batch specs: {', '.join(args.batch_specs)}") console.print(f"Batch specs: {', '.join(args.batch_specs)}")
console.print() console.print()
...@@ -850,6 +868,12 @@ def main(): ...@@ -850,6 +868,12 @@ def main():
else: else:
# Normal mode: compare backends # Normal mode: compare backends
decode_results = []
prefill_results = []
# Run decode backend comparison
if not prefill_backends:
# No prefill backends specified: compare decode backends as before
total = len(backends) * len(args.batch_specs) total = len(backends) * len(args.batch_specs)
with tqdm(total=total, desc="Benchmarking") as pbar: with tqdm(total=total, desc="Benchmarking") as pbar:
...@@ -870,17 +894,67 @@ def main(): ...@@ -870,17 +894,67 @@ def main():
) )
result = run_benchmark(config) result = run_benchmark(config)
all_results.append(result) decode_results.append(result)
if not result.success: if not result.success:
console.print(f"[red]Error {backend} {spec}: {result.error}[/]") console.print(
f"[red]Error {backend} {spec}: {result.error}[/]"
)
pbar.update(1) pbar.update(1)
# Display results
console.print("\n[bold green]Results:[/]") console.print("\n[bold green]Results:[/]")
formatter = ResultsFormatter(console) formatter = ResultsFormatter(console)
formatter.print_table(all_results, backends) formatter.print_table(decode_results, backends)
# Run prefill backend comparison
if prefill_backends:
# Use first decode backend for impl construction
decode_backend = backends[0]
total = len(prefill_backends) * len(args.batch_specs)
console.print(
f"[yellow]Prefill comparison mode: "
f"using {decode_backend} for decode impl[/]"
)
with tqdm(total=total, desc="Prefill benchmarking") as pbar:
for spec in args.batch_specs:
for pb in prefill_backends:
config = BenchmarkConfig(
backend=decode_backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
prefill_backend=pb,
)
result = run_benchmark(config)
# Label result with prefill backend name for display
labeled_config = replace(result.config, backend=pb)
result = replace(result, config=labeled_config)
prefill_results.append(result)
if not result.success:
console.print(f"[red]Error {pb} {spec}: {result.error}[/]")
pbar.update(1)
console.print("\n[bold green]Prefill Backend Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(
prefill_results, prefill_backends, compare_to_fastest=True
)
all_results = decode_results + prefill_results
# Save results # Save results
if all_results: if all_results:
......
...@@ -77,6 +77,7 @@ class MockKVBProj: ...@@ -77,6 +77,7 @@ class MockKVBProj:
self.qk_nope_head_dim = qk_nope_head_dim self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim self.v_head_dim = v_head_dim
self.out_dim = qk_nope_head_dim + v_head_dim self.out_dim = qk_nope_head_dim + v_head_dim
self.weight = torch.empty(0, dtype=torch.bfloat16)
def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]: def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
""" """
...@@ -213,6 +214,7 @@ class BenchmarkConfig: ...@@ -213,6 +214,7 @@ class BenchmarkConfig:
use_cuda_graphs: bool = False use_cuda_graphs: bool = False
# MLA-specific # MLA-specific
prefill_backend: str | None = None
kv_lora_rank: int | None = None kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None qk_rope_head_dim: int | None = None
......
# MLA prefill-only benchmark configuration for sparse backends # MLA prefill backend comparison
#
# Compares all available MLA prefill backends:
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
#
# Uses cutlass_mla as the decode backend for impl construction
# (only the prefill path is exercised).
#
# Backends that aren't available on the current platform will report errors
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
#
# Usage:
# python benchmark.py --config configs/mla_prefill.yaml
description: "MLA prefill backend comparison"
model: model:
name: "deepseek-v3" name: "deepseek-v3"
...@@ -12,20 +27,25 @@ model: ...@@ -12,20 +27,25 @@ model:
v_head_dim: 128 v_head_dim: 128
block_size: 128 block_size: 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads # model:
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads # name: "deepseek-v2-lite"
model_parameter_sweep: # num_layers: 27
param_name: "num_q_heads" # num_q_heads: 16
values: [128, 64, 32, 16] # num_kv_heads: 1
label_format: "{backend}_{value}h" # head_dim: 576
# kv_lora_rank: 512
# qk_nope_head_dim: 128
# qk_rope_head_dim: 64
# v_head_dim: 128
# block_size: 128
batch_specs: batch_specs:
# Pure prefill # Pure prefill
- "1q512" - "q512"
- "1q1k" - "q1k"
- "1q2k" - "q2k"
- "1q4k" - "q4k"
- "1q8k" - "q8k"
# Batched pure prefill # Batched pure prefill
- "2q512" - "2q512"
...@@ -44,19 +64,63 @@ batch_specs: ...@@ -44,19 +64,63 @@ batch_specs:
- "8q4k" - "8q4k"
- "8q8k" - "8q8k"
# Extend # Chunked prefill / extend
- "1q512s4k" # Short context
- "1q512s8k" - "q128s1k"
- "1q1ks8k" - "q256s2k"
- "1q2ks8k" - "q512s4k"
- "1q2ks16k" - "q1ks4k"
- "1q4ks16k" - "q2ks8k"
- "2q128s1k"
- "2q256s2k"
- "2q512s4k"
- "2q1ks4k"
- "2q2ks8k"
- "4q128s1k"
- "4q256s2k"
- "4q512s4k"
- "4q1ks4k"
- "4q2ks8k"
- "8q128s1k"
- "8q256s2k"
- "8q512s4k"
- "8q1ks4k"
# Medium context
- "q128s16k"
- "q512s16k"
- "q1ks16k"
- "q2ks16k"
- "2q128s16k"
- "2q512s16k"
- "2q1ks16k"
- "2q2ks16k"
- "4q128s16k"
- "4q512s16k"
- "4q1ks16k"
- "4q2ks16k"
# Long context
- "q128s64k"
- "q512s64k"
- "q1ks64k"
- "q2ks64k"
- "2q128s64k"
- "2q512s64k"
- "2q1ks64k"
- "2q2ks64k"
decode_backends:
- CUTLASS_MLA
backends: prefill_backends:
- FLASHMLA_SPARSE - fa2
- FLASHINFER_MLA_SPARSE - fa3
- fa4
- flashinfer
- cudnn
- trtllm
device: "cuda:0" device: "cuda:0"
repeats: 10 repeats: 20
warmup_iters: 3 warmup_iters: 5
profile_memory: true
# MLA prefill-only benchmark configuration for sparse backends
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128
num_kv_heads: 1
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep:
param_name: "num_q_heads"
values: [128, 64, 32, 16]
label_format: "{backend}_{value}h"
batch_specs:
# Pure prefill
- "1q512"
- "1q1k"
- "1q2k"
- "1q4k"
- "1q8k"
# Batched pure prefill
- "2q512"
- "2q1k"
- "2q2k"
- "2q4k"
- "2q8k"
- "4q512"
- "4q1k"
- "4q2k"
- "4q4k"
- "4q8k"
- "8q512"
- "8q1k"
- "8q2k"
- "8q4k"
- "8q8k"
# Extend
- "1q512s4k"
- "1q512s8k"
- "1q1ks8k"
- "1q2ks8k"
- "1q2ks16k"
- "1q4ks16k"
backends:
- FLASHMLA_SPARSE
- FLASHINFER_MLA_SPARSE
device: "cuda:0"
repeats: 10
warmup_iters: 3
profile_memory: true
...@@ -62,6 +62,7 @@ def create_minimal_vllm_config( ...@@ -62,6 +62,7 @@ def create_minimal_vllm_config(
max_num_seqs: int = 256, max_num_seqs: int = 256,
mla_dims: dict | None = None, mla_dims: dict | None = None,
index_topk: int | None = None, index_topk: int | None = None,
prefill_backend: str | None = None,
) -> VllmConfig: ) -> VllmConfig:
""" """
Create minimal VllmConfig for MLA benchmarks. Create minimal VllmConfig for MLA benchmarks.
...@@ -75,6 +76,9 @@ def create_minimal_vllm_config( ...@@ -75,6 +76,9 @@ def create_minimal_vllm_config(
setup_mla_dims(model_name) setup_mla_dims(model_name)
index_topk: Optional topk value for sparse MLA backends. If provided, index_topk: Optional topk value for sparse MLA backends. If provided,
the config will include index_topk for sparse attention. the config will include index_topk for sparse attention.
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
"cudnn", "trtllm"). Configures the attention config to
force the specified prefill backend.
Returns: Returns:
VllmConfig for benchmarking VllmConfig for benchmarking
...@@ -163,7 +167,7 @@ def create_minimal_vllm_config( ...@@ -163,7 +167,7 @@ def create_minimal_vllm_config(
compilation_config = CompilationConfig() compilation_config = CompilationConfig()
return VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
...@@ -171,9 +175,84 @@ def create_minimal_vllm_config( ...@@ -171,9 +175,84 @@ def create_minimal_vllm_config(
compilation_config=compilation_config, compilation_config=compilation_config,
) )
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
if prefill_cfg["flash_attn_version"] is not None:
vllm_config.attention_config.flash_attn_version = prefill_cfg[
"flash_attn_version"
]
vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[
"disable_flashinfer_prefill"
]
vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[
"use_cudnn_prefill"
]
vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[
"use_trtllm_ragged_deepseek_prefill"
]
return vllm_config
# ============================================================================ # ============================================================================
# Backend Configuration # Prefill Backend Configuration
# ============================================================================
# Maps prefill backend names to attention config overrides.
# FA backends set flash_attn_version and disable non-FA paths.
# Non-FA backends enable their specific path and disable others.
_PREFILL_BACKEND_CONFIG: dict[str, dict] = {
"fa2": {
"flash_attn_version": 2,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"fa3": {
"flash_attn_version": 3,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"fa4": {
"flash_attn_version": 4,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"flashinfer": {
"flash_attn_version": None,
"disable_flashinfer_prefill": False,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": False,
},
"cudnn": {
"flash_attn_version": None,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": True,
"use_trtllm_ragged_deepseek_prefill": False,
},
"trtllm": {
"flash_attn_version": None,
"disable_flashinfer_prefill": True,
"use_cudnn_prefill": False,
"use_trtllm_ragged_deepseek_prefill": True,
},
}
def get_prefill_backend_config(prefill_backend: str) -> dict:
"""Get attention config overrides for a prefill backend."""
if prefill_backend not in _PREFILL_BACKEND_CONFIG:
raise ValueError(
f"Unknown prefill backend: {prefill_backend!r}. "
f"Available: {list(_PREFILL_BACKEND_CONFIG.keys())}"
)
return _PREFILL_BACKEND_CONFIG[prefill_backend]
# ============================================================================
# Decode Backend Configuration
# ============================================================================ # ============================================================================
...@@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict: ...@@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict:
Returns: Returns:
Dict with backend configuration Dict with backend configuration
""" """
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
try: try:
...@@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict: ...@@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict:
block_sizes = backend_class.get_supported_kernel_block_sizes() block_sizes = backend_class.get_supported_kernel_block_sizes()
# Use first supported block size (backends typically support one for MLA) # Use first supported block size (backends typically support one for MLA)
block_size = block_sizes[0] if block_sizes else None block_size = block_sizes[0] if block_sizes else None
if hasattr(block_size, "value"): if isinstance(block_size, MultipleOf):
# Handle MultipleOf enum # No fixed block size; fall back to config value
block_size = None block_size = None
# Check if sparse via class method if available # Check if sparse via class method if available
...@@ -676,16 +756,11 @@ def _run_single_benchmark( ...@@ -676,16 +756,11 @@ def _run_single_benchmark(
if is_sparse and indexer is not None: if is_sparse and indexer is not None:
indexer.fill_random_indices(total_q, max_kv_len) indexer.fill_random_indices(total_q, max_kv_len)
# Determine which forward method to use # Determine which forward method to use based on metadata
if is_sparse: if metadata.decode is not None:
# Sparse backends use forward_mqa
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
elif metadata.decode is not None:
forward_fn = lambda: impl._forward_decode(
decode_inputs, kv_cache, metadata, layer
)
elif metadata.prefill is not None: elif metadata.prefill is not None:
forward_fn = lambda: impl._forward_prefill( forward_fn = lambda: impl.forward_mha(
prefill_inputs["q"], prefill_inputs["q"],
prefill_inputs["k_c_normed"], prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"], prefill_inputs["k_pe"],
...@@ -732,6 +807,7 @@ def _run_mla_benchmark_batched( ...@@ -732,6 +807,7 @@ def _run_mla_benchmark_batched(
backend: str, backend: str,
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...] configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
index_topk: int = 2048, index_topk: int = 2048,
prefill_backend: str | None = None,
) -> list[BenchmarkResult]: ) -> list[BenchmarkResult]:
""" """
Unified batched MLA benchmark runner for all backends. Unified batched MLA benchmark runner for all backends.
...@@ -743,11 +819,13 @@ def _run_mla_benchmark_batched( ...@@ -743,11 +819,13 @@ def _run_mla_benchmark_batched(
to avoid setup/teardown overhead. to avoid setup/teardown overhead.
Args: Args:
backend: Backend name backend: Backend name (decode backend used for impl construction)
configs_with_params: List of (config, threshold, num_splits) tuples configs_with_params: List of (config, threshold, num_splits) tuples
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only) - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
- num_splits: num_kv_splits (CUTLASS only) - num_splits: num_kv_splits (CUTLASS only)
index_topk: Topk value for sparse MLA backends (default 2048) index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
Returns: Returns:
List of BenchmarkResult objects List of BenchmarkResult objects
...@@ -780,11 +858,25 @@ def _run_mla_benchmark_batched( ...@@ -780,11 +858,25 @@ def _run_mla_benchmark_batched(
block_size=block_size, block_size=block_size,
mla_dims=mla_dims, # Use custom dims from config or default mla_dims=mla_dims, # Use custom dims from config or default
index_topk=index_topk if is_sparse else None, index_topk=index_topk if is_sparse else None,
prefill_backend=prefill_backend,
) )
results = [] results = []
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
# Clear cached prefill backend detection functions so they re-evaluate
# with the current VllmConfig. These are @functools.cache decorated and
# would otherwise return stale results from a previous backend's config.
from vllm.model_executor.layers.attention.mla_attention import (
use_cudnn_prefill,
use_flashinfer_prefill,
use_trtllm_ragged_deepseek_prefill,
)
use_flashinfer_prefill.cache_clear()
use_cudnn_prefill.cache_clear()
use_trtllm_ragged_deepseek_prefill.cache_clear()
# Create backend impl, layer, builder, and indexer (reused across benchmarks) # Create backend impl, layer, builder, and indexer (reused across benchmarks)
impl, layer, builder_instance, indexer = _create_backend_impl( impl, layer, builder_instance, indexer = _create_backend_impl(
backend_cfg, backend_cfg,
...@@ -794,6 +886,38 @@ def _run_mla_benchmark_batched( ...@@ -794,6 +886,38 @@ def _run_mla_benchmark_batched(
index_topk=index_topk if is_sparse else None, index_topk=index_topk if is_sparse else None,
) )
# Verify the actual prefill backend matches what was requested
if prefill_backend is not None:
prefill_cfg = get_prefill_backend_config(prefill_backend)
fa_version = prefill_cfg["flash_attn_version"]
if fa_version is not None:
# FA backend: verify the impl's FA version
actual_fa_version = getattr(impl, "vllm_flash_attn_version", None)
if actual_fa_version != fa_version:
raise RuntimeError(
f"Prefill backend '{prefill_backend}' requested FA "
f"version {fa_version}, but the impl is using FA "
f"version {actual_fa_version}. Check "
f"vllm/v1/attention/backends/fa_utils.py."
)
else:
# Non-FA backend: verify the builder picked the right path
expected_flags = {
"flashinfer": "_use_fi_prefill",
"cudnn": "_use_cudnn_prefill",
"trtllm": "_use_trtllm_ragged_prefill",
}
flag_name = expected_flags.get(prefill_backend)
if flag_name and not getattr(builder_instance, flag_name, False):
raise RuntimeError(
f"Prefill backend '{prefill_backend}' was requested "
f"but the metadata builder did not enable it. This "
f"usually means a dependency is missing (e.g., "
f"flashinfer not installed) or the platform doesn't "
f"support it."
)
# Run each benchmark with the shared impl # Run each benchmark with the shared impl
for config, threshold, num_splits in configs_with_params: for config, threshold, num_splits in configs_with_params:
# Set threshold for this benchmark (FlashAttn/FlashMLA only) # Set threshold for this benchmark (FlashAttn/FlashMLA only)
...@@ -844,6 +968,7 @@ def run_mla_benchmark( ...@@ -844,6 +968,7 @@ def run_mla_benchmark(
reorder_batch_threshold: int | None = None, reorder_batch_threshold: int | None = None,
num_kv_splits: int | None = None, num_kv_splits: int | None = None,
index_topk: int = 2048, index_topk: int = 2048,
prefill_backend: str | None = None,
) -> BenchmarkResult | list[BenchmarkResult]: ) -> BenchmarkResult | list[BenchmarkResult]:
""" """
Unified MLA benchmark runner for all backends. Unified MLA benchmark runner for all backends.
...@@ -861,6 +986,8 @@ def run_mla_benchmark( ...@@ -861,6 +986,8 @@ def run_mla_benchmark(
(single config mode only) (single config mode only)
num_kv_splits: Number of KV splits for CUTLASS (single config mode only) num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
index_topk: Topk value for sparse MLA backends (default 2048) index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
Returns: Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode) BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
...@@ -884,7 +1011,9 @@ def run_mla_benchmark( ...@@ -884,7 +1011,9 @@ def run_mla_benchmark(
return_single = True return_single = True
# Use unified batched execution # Use unified batched execution
results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk) results = _run_mla_benchmark_batched(
backend, configs_with_params, index_topk, prefill_backend=prefill_backend
)
# Return single result or list based on input # Return single result or list based on input
return results[0] if return_single else results return results[0] if return_single else results
...@@ -39,7 +39,7 @@ else() ...@@ -39,7 +39,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2 GIT_TAG 1488682bb545f7d020e958a33116b1419d1cfc83
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
...@@ -30,14 +30,14 @@ class AttentionConfig: ...@@ -30,14 +30,14 @@ class AttentionConfig:
use_cudnn_prefill: bool = False use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill.""" """Whether to use cudnn prefill."""
use_trtllm_ragged_deepseek_prefill: bool = True use_trtllm_ragged_deepseek_prefill: bool = False
"""Whether to use TRTLLM ragged deepseek prefill.""" """Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention: bool | None = None use_trtllm_attention: bool | None = None
"""If set to True/False, use or don't use the TRTLLM attention backend """If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer.""" in flashinfer. If None, auto-detect the attention backend in flashinfer."""
disable_flashinfer_prefill: bool = False disable_flashinfer_prefill: bool = True
"""Whether to disable flashinfer prefill.""" """Whether to disable flashinfer prefill."""
disable_flashinfer_q_quantization: bool = False disable_flashinfer_q_quantization: bool = False
......
...@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool: ...@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
@functools.cache @functools.cache
def use_flashinfer_prefill() -> bool: def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
...@@ -2154,14 +2152,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2154,14 +2152,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# For MLA the v head dim is smaller than qk head dim so we pad out # For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do # v with 0s to match the qk head dim for attention backends that do
# not support different headdims # not support different headdims.
# We don't need to pad V if we are on a hopper system with FA3 # FA3 on Hopper (SM90) and FA4 natively handle diff headdims.
device_capability = current_platform.get_device_capability() device_capability = current_platform.get_device_capability()
self._pad_v = self.vllm_flash_attn_version is None or not ( self._pad_v = self.vllm_flash_attn_version is None or not (
(
self.vllm_flash_attn_version == 3 self.vllm_flash_attn_version == 3
and device_capability is not None and device_capability is not None
and device_capability[0] == 9 and device_capability[0] == 9
) )
or self.vllm_flash_attn_version == 4
)
self.dcp_world_size: int = -1 self.dcp_world_size: int = -1
......
...@@ -125,11 +125,14 @@ def get_flash_attn_version( ...@@ -125,11 +125,14 @@ def get_flash_attn_version(
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions. # supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959 # See: https://github.com/Dao-AILab/flash-attention/issues/1959
# Exception: hdim 192 is supported for MLA's diff-headdim case
# (qk=192, v=128), added upstream in commits 1a15733e/1b36ab19.
if ( if (
fa_version == 4 fa_version == 4
and device_capability.major >= 10 and device_capability.major >= 10
and head_size is not None and head_size is not None
and head_size > 128 and head_size > 128
and head_size != 192
): ):
logger.warning_once( logger.warning_once(
"FA4 on Blackwell does not support head_size=%d due to TMEM " "FA4 on Blackwell does not support head_size=%d due to TMEM "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment