Unverified commit a3a51d20, authored by Wei Zhao and committed by GitHub
Browse files

[Benchmark] Improvements to attention benchmark script (#37115)


Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
parent e5b80760
...@@ -47,6 +47,8 @@ from common import ( ...@@ -47,6 +47,8 @@ from common import (
is_mla_backend, is_mla_backend,
) )
from vllm.v1.worker.workspace import init_workspace_manager
def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""Run standard attention benchmark (Flash/Triton/FlashInfer).""" """Run standard attention benchmark (Flash/Triton/FlashInfer)."""
...@@ -462,7 +464,7 @@ def main(): ...@@ -462,7 +464,7 @@ def main():
parser.add_argument( parser.add_argument(
"--batch-specs", "--batch-specs",
nargs="+", nargs="+",
default=["q2k", "8q1s1k"], default=None,
help="Batch specifications using extended grammar", help="Batch specifications using extended grammar",
) )
...@@ -478,6 +480,21 @@ def main(): ...@@ -478,6 +480,21 @@ def main():
parser.add_argument("--repeats", type=int, default=1, help="Repetitions") parser.add_argument("--repeats", type=int, default=1, help="Repetitions")
parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations") parser.add_argument("--warmup-iters", type=int, default=3, help="Warmup iterations")
parser.add_argument("--profile-memory", action="store_true", help="Profile memory") parser.add_argument("--profile-memory", action="store_true", help="Profile memory")
parser.add_argument(
"--kv-cache-dtype",
default="auto",
choices=["auto", "fp8"],
help="KV cache dtype: auto or fp8",
)
parser.add_argument(
"--cuda-graphs",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Launch kernels with CUDA graphs to eliminate CPU overhead"
"in measurements (default: True)"
),
)
# Parameter sweep (use YAML config for advanced sweeps) # Parameter sweep (use YAML config for advanced sweeps)
parser.add_argument( parser.add_argument(
...@@ -536,21 +553,24 @@ def main(): ...@@ -536,21 +553,24 @@ def main():
# Batch specs and sizes # Batch specs and sizes
# Support both explicit batch_specs and generated batch_spec_ranges # Support both explicit batch_specs and generated batch_spec_ranges
if "batch_spec_ranges" in yaml_config: # CLI --batch-specs takes precedence over YAML when provided.
# Generate batch specs from ranges cli_batch_specs_provided = args.batch_specs is not None
generated_specs = generate_batch_specs_from_ranges( if not cli_batch_specs_provided:
yaml_config["batch_spec_ranges"] if "batch_spec_ranges" in yaml_config:
) # Generate batch specs from ranges
# Combine with any explicit batch_specs generated_specs = generate_batch_specs_from_ranges(
if "batch_specs" in yaml_config: yaml_config["batch_spec_ranges"]
args.batch_specs = yaml_config["batch_specs"] + generated_specs )
else: # Combine with any explicit batch_specs
args.batch_specs = generated_specs if "batch_specs" in yaml_config:
console.print( args.batch_specs = yaml_config["batch_specs"] + generated_specs
f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" else:
) args.batch_specs = generated_specs
elif "batch_specs" in yaml_config: console.print(
args.batch_specs = yaml_config["batch_specs"] f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]"
)
elif "batch_specs" in yaml_config:
args.batch_specs = yaml_config["batch_specs"]
if "batch_sizes" in yaml_config: if "batch_sizes" in yaml_config:
args.batch_sizes = yaml_config["batch_sizes"] args.batch_sizes = yaml_config["batch_sizes"]
...@@ -575,6 +595,10 @@ def main(): ...@@ -575,6 +595,10 @@ def main():
args.warmup_iters = yaml_config["warmup_iters"] args.warmup_iters = yaml_config["warmup_iters"]
if "profile_memory" in yaml_config: if "profile_memory" in yaml_config:
args.profile_memory = yaml_config["profile_memory"] args.profile_memory = yaml_config["profile_memory"]
if "kv_cache_dtype" in yaml_config:
args.kv_cache_dtype = yaml_config["kv_cache_dtype"]
if "cuda_graphs" in yaml_config:
args.cuda_graphs = yaml_config["cuda_graphs"]
# Parameter sweep configuration # Parameter sweep configuration
if "parameter_sweep" in yaml_config: if "parameter_sweep" in yaml_config:
...@@ -629,12 +653,18 @@ def main(): ...@@ -629,12 +653,18 @@ def main():
# Determine backends # Determine backends
backends = args.backends or ([args.backend] if args.backend else ["flash"]) backends = args.backends or ([args.backend] if args.backend else ["flash"])
prefill_backends = getattr(args, "prefill_backends", None) prefill_backends = getattr(args, "prefill_backends", None)
if not args.batch_specs:
args.batch_specs = ["q2k", "8q1s1k"]
console.print(f"Backends: {', '.join(backends)}") console.print(f"Backends: {', '.join(backends)}")
if prefill_backends: if prefill_backends:
console.print(f"Prefill backends: {', '.join(prefill_backends)}") console.print(f"Prefill backends: {', '.join(prefill_backends)}")
console.print(f"Batch specs: {', '.join(args.batch_specs)}") console.print(f"Batch specs: {', '.join(args.batch_specs)}")
console.print(f"KV cache dtype: {args.kv_cache_dtype}")
console.print(f"CUDA graphs: {args.cuda_graphs}")
console.print() console.print()
init_workspace_manager(args.device)
# Run benchmarks # Run benchmarks
all_results = [] all_results = []
...@@ -687,6 +717,8 @@ def main(): ...@@ -687,6 +717,8 @@ def main():
repeats=args.repeats, repeats=args.repeats,
warmup_iters=args.warmup_iters, warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory, profile_memory=args.profile_memory,
kv_cache_dtype=args.kv_cache_dtype,
use_cuda_graphs=args.cuda_graphs,
) )
# Add decode pipeline config # Add decode pipeline config
...@@ -839,6 +871,8 @@ def main(): ...@@ -839,6 +871,8 @@ def main():
"repeats": args.repeats, "repeats": args.repeats,
"warmup_iters": args.warmup_iters, "warmup_iters": args.warmup_iters,
"profile_memory": args.profile_memory, "profile_memory": args.profile_memory,
"kv_cache_dtype": args.kv_cache_dtype,
"use_cuda_graphs": args.cuda_graphs,
} }
all_results = run_model_parameter_sweep( all_results = run_model_parameter_sweep(
backends, backends,
...@@ -861,6 +895,8 @@ def main(): ...@@ -861,6 +895,8 @@ def main():
"repeats": args.repeats, "repeats": args.repeats,
"warmup_iters": args.warmup_iters, "warmup_iters": args.warmup_iters,
"profile_memory": args.profile_memory, "profile_memory": args.profile_memory,
"kv_cache_dtype": args.kv_cache_dtype,
"use_cuda_graphs": args.cuda_graphs,
} }
all_results = run_parameter_sweep( all_results = run_parameter_sweep(
backends, args.batch_specs, base_config_args, args.parameter_sweep, console backends, args.batch_specs, base_config_args, args.parameter_sweep, console
...@@ -891,6 +927,8 @@ def main(): ...@@ -891,6 +927,8 @@ def main():
repeats=args.repeats, repeats=args.repeats,
warmup_iters=args.warmup_iters, warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory, profile_memory=args.profile_memory,
kv_cache_dtype=args.kv_cache_dtype,
use_cuda_graphs=args.cuda_graphs,
) )
result = run_benchmark(config) result = run_benchmark(config)
......
...@@ -213,6 +213,9 @@ class BenchmarkConfig: ...@@ -213,6 +213,9 @@ class BenchmarkConfig:
profile_memory: bool = False profile_memory: bool = False
use_cuda_graphs: bool = False use_cuda_graphs: bool = False
# "auto" or "fp8"
kv_cache_dtype: str = "auto"
# MLA-specific # MLA-specific
prefill_backend: str | None = None prefill_backend: str | None = None
kv_lora_rank: int | None = None kv_lora_rank: int | None = None
...@@ -369,6 +372,7 @@ class ResultsFormatter: ...@@ -369,6 +372,7 @@ class ResultsFormatter:
"backend", "backend",
"batch_spec", "batch_spec",
"num_layers", "num_layers",
"kv_cache_dtype",
"mean_time", "mean_time",
"std_time", "std_time",
"throughput", "throughput",
...@@ -382,6 +386,7 @@ class ResultsFormatter: ...@@ -382,6 +386,7 @@ class ResultsFormatter:
"backend": r.config.backend, "backend": r.config.backend,
"batch_spec": r.config.batch_spec, "batch_spec": r.config.batch_spec,
"num_layers": r.config.num_layers, "num_layers": r.config.num_layers,
"kv_cache_dtype": r.config.kv_cache_dtype,
"mean_time": r.mean_time, "mean_time": r.mean_time,
"std_time": r.std_time, "std_time": r.std_time,
"throughput": r.throughput_tokens_per_sec or 0, "throughput": r.throughput_tokens_per_sec or 0,
......
...@@ -30,9 +30,9 @@ batch_specs: ...@@ -30,9 +30,9 @@ batch_specs:
- "2q16k_32q1s4k" # 2 very large prefill + 32 decode - "2q16k_32q1s4k" # 2 very large prefill + 32 decode
# Context extension + decode # Context extension + decode
- "2q1kkv2k_16q1s1k" # 2 extend + 16 decode - "2q1ks2k_16q1s1k" # 2 extend + 16 decode
- "4q2kkv4k_32q1s2k" # 4 extend + 32 decode - "4q2ks4k_32q1s2k" # 4 extend + 32 decode
- "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode - "2q1ks8k_32q1s2k" # 2 large extend + 32 decode
# Explicitly chunked prefill # Explicitly chunked prefill
- "q8k" # 8k prefill with chunking hint - "q8k" # 8k prefill with chunking hint
......
# Sparse MLA decode-only benchmark configuration.
# Every batch spec below is a pure-decode batch (query length 1), and only
# the sparse MLA backends listed under `backends` are exercised.

# Model dimensions used to build the benchmark inputs.
model:
name: "deepseek-v3"
num_layers: 60
num_q_heads: 128 # Base value, can be swept for TP simulation
num_kv_heads: 1 # MLA uses single latent KV
# Note: head_dim equals kv_lora_rank + qk_rope_head_dim (512 + 64 = 576).
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep:
param_name: "num_q_heads"
values: [128, 64, 32, 16]
# Result label template; {backend} and {value} are placeholders.
label_format: "{backend}_{value}h"
# Batch specs have the form "<num requests>q1s<KV cache length>".
batch_specs:
# Small batches, varying sequence lengths
- "16q1s512" # 16 requests, 512 KV cache
- "16q1s1k" # 16 requests, 1k KV cache
- "16q1s2k" # 16 requests, 2k KV cache
- "16q1s4k" # 16 requests, 4k KV cache
# Medium batches
- "32q1s1k" # 32 requests, 1k KV cache
- "32q1s2k" # 32 requests, 2k KV cache
- "32q1s4k" # 32 requests, 4k KV cache
- "32q1s8k" # 32 requests, 8k KV cache
# Large batches
- "64q1s1k" # 64 requests, 1k KV cache
- "64q1s2k" # 64 requests, 2k KV cache
- "64q1s4k" # 64 requests, 4k KV cache
- "64q1s8k" # 64 requests, 8k KV cache
# Very large batches
- "128q1s1k" # 128 requests, 1k KV cache
- "128q1s2k" # 128 requests, 2k KV cache
- "128q1s4k" # 128 requests, 4k KV cache
- "128q1s8k" # 128 requests, 8k KV cache
# Long context
- "32q1s16k" # 32 requests, 16k KV cache
- "32q1s32k" # 32 requests, 32k KV cache
# Sparse MLA backends to compare.
backends:
- FLASHMLA_SPARSE
- FLASHINFER_MLA_SPARSE
# Run settings: target device, timed repetitions, warmup iterations, and
# whether memory profiling is enabled.
device: "cuda:0"
repeats: 100
warmup_iters: 10
profile_memory: true
...@@ -60,9 +60,11 @@ def create_minimal_vllm_config( ...@@ -60,9 +60,11 @@ def create_minimal_vllm_config(
model_name: str = "deepseek-v3", model_name: str = "deepseek-v3",
block_size: int = 128, block_size: int = 128,
max_num_seqs: int = 256, max_num_seqs: int = 256,
max_num_batched_tokens: int = 8192,
mla_dims: dict | None = None, mla_dims: dict | None = None,
index_topk: int | None = None, index_topk: int | None = None,
prefill_backend: str | None = None, prefill_backend: str | None = None,
kv_cache_dtype: str = "auto",
) -> VllmConfig: ) -> VllmConfig:
""" """
Create minimal VllmConfig for MLA benchmarks. Create minimal VllmConfig for MLA benchmarks.
...@@ -149,13 +151,13 @@ def create_minimal_vllm_config( ...@@ -149,13 +151,13 @@ def create_minimal_vllm_config(
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=block_size, block_size=block_size,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
cache_dtype="auto", cache_dtype=kv_cache_dtype,
enable_prefix_caching=False, enable_prefix_caching=False,
) )
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_num_batched_tokens=8192, max_num_batched_tokens=max(max_num_batched_tokens, max_num_seqs),
max_model_len=32768, max_model_len=32768,
is_encoder_decoder=False, is_encoder_decoder=False,
enable_chunked_prefill=True, enable_chunked_prefill=True,
...@@ -535,6 +537,7 @@ def _create_backend_impl( ...@@ -535,6 +537,7 @@ def _create_backend_impl(
device: torch.device, device: torch.device,
max_num_tokens: int = 8192, max_num_tokens: int = 8192,
index_topk: int | None = None, index_topk: int | None = None,
kv_cache_dtype: str = "auto",
): ):
""" """
Create backend implementation instance. Create backend implementation instance.
...@@ -583,7 +586,7 @@ def _create_backend_impl( ...@@ -583,7 +586,7 @@ def _create_backend_impl(
"num_kv_heads": mla_dims["num_kv_heads"], "num_kv_heads": mla_dims["num_kv_heads"],
"alibi_slopes": None, "alibi_slopes": None,
"sliding_window": None, "sliding_window": None,
"kv_cache_dtype": "auto", "kv_cache_dtype": kv_cache_dtype,
"logits_soft_cap": None, "logits_soft_cap": None,
"attn_type": "decoder", "attn_type": "decoder",
"kv_sharing_target_layer_name": None, "kv_sharing_target_layer_name": None,
...@@ -701,6 +704,7 @@ def _run_single_benchmark( ...@@ -701,6 +704,7 @@ def _run_single_benchmark(
mla_dims: dict, mla_dims: dict,
device: torch.device, device: torch.device,
indexer=None, indexer=None,
kv_cache_dtype: str | None = None,
) -> BenchmarkResult: ) -> BenchmarkResult:
""" """
Run a single benchmark iteration. Run a single benchmark iteration.
...@@ -734,49 +738,124 @@ def _run_single_benchmark( ...@@ -734,49 +738,124 @@ def _run_single_benchmark(
) )
# Create KV cache # Create KV cache
kv_cache = torch.zeros( if kv_cache_dtype is None:
num_blocks, kv_cache_dtype = getattr(config, "kv_cache_dtype", "auto")
block_size, head_size = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"]
mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"], if kv_cache_dtype == "fp8_ds_mla":
device=device, # FlashMLA sparse custom format: 656 bytes per token, stored as uint8.
dtype=torch.bfloat16, # Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales
) # + 2*rope_dim bf16 bytes
# = 512 + 16 + 128 = 656 bytes for DeepSeek dims.
kv_cache = torch.zeros(
num_blocks,
block_size,
656,
device=device,
dtype=torch.uint8,
)
elif kv_cache_dtype == "fp8":
from vllm.platforms import current_platform
# Create input tensors for both decode and prefill modes kv_cache = torch.zeros(
decode_inputs, prefill_inputs = _create_input_tensors( num_blocks,
total_q, block_size,
mla_dims, head_size,
backend_cfg["query_format"], device=device,
device, dtype=torch.uint8,
torch.bfloat16, ).view(current_platform.fp8_dtype())
) else:
kv_cache = torch.zeros(
num_blocks,
block_size,
head_size,
device=device,
dtype=torch.bfloat16,
)
# Fill indexer with random indices for sparse backends # Fill indexer with random indices for sparse backends
is_sparse = backend_cfg.get("is_sparse", False) is_sparse = backend_cfg.get("is_sparse", False)
if is_sparse and indexer is not None: if is_sparse and indexer is not None:
indexer.fill_random_indices(total_q, max_kv_len) indexer.fill_random_indices(total_q, max_kv_len)
# Determine which forward method to use based on metadata # Determine which forward methods to use based on metadata.
if metadata.decode is not None: # Sparse MLA backends always use forward_mqa
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) has_decode = is_sparse or getattr(metadata, "decode", None) is not None
elif metadata.prefill is not None: has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None
forward_fn = lambda: impl.forward_mha( if not has_decode and not has_prefill:
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
else:
raise RuntimeError("Metadata has neither decode nor prefill metadata") raise RuntimeError("Metadata has neither decode nor prefill metadata")
num_decode = (
metadata.num_decode_tokens
if (has_decode and has_prefill)
else total_q
if has_decode
else 0
)
num_prefill = total_q - num_decode
# Some backends require fp8 queries when using an fp8 KV cache.
is_fp8_kvcache = kv_cache_dtype.startswith("fp8")
quantize_query = is_fp8_kvcache and getattr(
impl, "supports_quant_query_input", False
)
# quantize_query forces concat format
query_fmt = "concat" if quantize_query else backend_cfg["query_format"]
# Create decode query tensors
if has_decode:
decode_inputs, _ = _create_input_tensors(
num_decode, mla_dims, query_fmt, device, torch.bfloat16
)
# Cast decode query to fp8 if the backend supports it
if quantize_query:
from vllm.platforms import current_platform
if isinstance(decode_inputs, tuple):
decode_inputs = torch.cat(list(decode_inputs), dim=-1)
decode_inputs = decode_inputs.to(current_platform.fp8_dtype())
# Create prefill input tensors
if has_prefill:
_, prefill_inputs = _create_input_tensors(
num_prefill, mla_dims, query_fmt, device, torch.bfloat16
)
# Build forward function
def forward_fn():
results = []
if has_decode:
results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer))
if has_prefill:
results.append(
impl.forward_mha(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
)
return results[0] if len(results) == 1 else tuple(results)
# Warmup # Warmup
for _ in range(config.warmup_iters): for _ in range(config.warmup_iters):
forward_fn() forward_fn()
torch.accelerator.synchronize() torch.accelerator.synchronize()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if config.use_cuda_graphs:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
forward_fn()
benchmark_fn = graph.replay
else:
benchmark_fn = forward_fn
# Benchmark # Benchmark
times = [] times = []
for _ in range(config.repeats): for _ in range(config.repeats):
...@@ -785,7 +864,7 @@ def _run_single_benchmark( ...@@ -785,7 +864,7 @@ def _run_single_benchmark(
start.record() start.record()
for _ in range(config.num_layers): for _ in range(config.num_layers):
forward_fn() benchmark_fn()
end.record() end.record()
torch.accelerator.synchronize() torch.accelerator.synchronize()
...@@ -852,13 +931,30 @@ def _run_mla_benchmark_batched( ...@@ -852,13 +931,30 @@ def _run_mla_benchmark_batched(
# Determine if this is a sparse backend # Determine if this is a sparse backend
is_sparse = backend_cfg.get("is_sparse", False) is_sparse = backend_cfg.get("is_sparse", False)
# Extract kv_cache_dtype from the first config
kv_cache_dtype = getattr(first_config, "kv_cache_dtype", "auto")
# FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8").
# Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend.
if backend.upper() == "FLASHMLA_SPARSE" and kv_cache_dtype == "fp8":
kv_cache_dtype = "fp8_ds_mla"
# Compute max total_q across all configs so the metadata builder buffer
# and scheduler config are large enough for all batch specs.
max_total_q = max(
sum(r.q_len for r in parse_batch_spec(cfg.batch_spec))
for cfg, *_ in configs_with_params
)
# Create and set vLLM config for MLA (reused across all benchmarks) # Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config = create_minimal_vllm_config( vllm_config = create_minimal_vllm_config(
model_name="deepseek-v3", # Used only for model path model_name="deepseek-v3", # Used only for model path
block_size=block_size, block_size=block_size,
max_num_batched_tokens=max_total_q,
mla_dims=mla_dims, # Use custom dims from config or default mla_dims=mla_dims, # Use custom dims from config or default
index_topk=index_topk if is_sparse else None, index_topk=index_topk if is_sparse else None,
prefill_backend=prefill_backend, prefill_backend=prefill_backend,
kv_cache_dtype=kv_cache_dtype,
) )
results = [] results = []
...@@ -883,7 +979,9 @@ def _run_mla_benchmark_batched( ...@@ -883,7 +979,9 @@ def _run_mla_benchmark_batched(
mla_dims, mla_dims,
vllm_config, vllm_config,
device, device,
max_num_tokens=max_total_q,
index_topk=index_topk if is_sparse else None, index_topk=index_topk if is_sparse else None,
kv_cache_dtype=kv_cache_dtype,
) )
# Verify the actual prefill backend matches what was requested # Verify the actual prefill backend matches what was requested
...@@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched( ...@@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched(
mla_dims, mla_dims,
device, device,
indexer=indexer, indexer=indexer,
kv_cache_dtype=kv_cache_dtype,
) )
results.append(result) results.append(result)
......
...@@ -140,7 +140,7 @@ def _create_vllm_config( ...@@ -140,7 +140,7 @@ def _create_vllm_config(
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=config.block_size, block_size=config.block_size,
cache_dtype="auto", cache_dtype=config.kv_cache_dtype,
) )
cache_config.num_gpu_blocks = max_num_blocks cache_config.num_gpu_blocks = max_num_blocks
cache_config.num_cpu_blocks = 0 cache_config.num_cpu_blocks = 0
...@@ -215,7 +215,7 @@ def _create_backend_impl( ...@@ -215,7 +215,7 @@ def _create_backend_impl(
num_kv_heads=config.num_kv_heads, num_kv_heads=config.num_kv_heads,
alibi_slopes=None, alibi_slopes=None,
sliding_window=None, sliding_window=None,
kv_cache_dtype="auto", kv_cache_dtype=config.kv_cache_dtype,
) )
kv_cache_spec = FullAttentionSpec( kv_cache_spec = FullAttentionSpec(
...@@ -288,12 +288,22 @@ def _create_input_tensors( ...@@ -288,12 +288,22 @@ def _create_input_tensors(
total_q: int, total_q: int,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
quantize_query: bool = False,
) -> tuple: ) -> tuple:
"""Create Q, K, V input tensors for all layers.""" """Create Q, K, V input tensors for all layers.
When quantize_query is True, queries are cast to fp8 to match backends
that require query/key/value dtype consistency.
"""
q_dtype = dtype
if quantize_query:
from vllm.platforms import current_platform
q_dtype = current_platform.fp8_dtype()
q_list = [ q_list = [
torch.randn( torch.randn(
total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype total_q, config.num_q_heads, config.head_dim, device=device, dtype=dtype
) ).to(q_dtype)
for _ in range(config.num_layers) for _ in range(config.num_layers)
] ]
k_list = [ k_list = [
...@@ -344,10 +354,17 @@ def _create_kv_cache( ...@@ -344,10 +354,17 @@ def _create_kv_cache(
# Compute inverse permutation to get back to logical view # Compute inverse permutation to get back to logical view
inv_order = [stride_order.index(i) for i in range(len(stride_order))] inv_order = [stride_order.index(i) for i in range(len(stride_order))]
# Use fp8 dtype for cache when requested.
cache_dtype = dtype
if config.kv_cache_dtype == "fp8":
from vllm.platforms import current_platform
cache_dtype = current_platform.fp8_dtype()
cache_list = [] cache_list = []
for _ in range(config.num_layers): for _ in range(config.num_layers):
# Allocate in physical layout order (contiguous in memory) # Allocate in physical layout order (contiguous in memory)
cache = torch.zeros(*physical_shape, device=device, dtype=dtype) cache = torch.zeros(*physical_shape, device=device, dtype=cache_dtype)
# Permute to logical view # Permute to logical view
cache = cache.permute(*inv_order) cache = cache.permute(*inv_order)
cache_list.append(cache) cache_list.append(cache)
...@@ -392,6 +409,37 @@ def _run_single_benchmark( ...@@ -392,6 +409,37 @@ def _run_single_benchmark(
) )
torch.accelerator.synchronize() torch.accelerator.synchronize()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if config.use_cuda_graphs:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
benchmark_fn = graph.replay
else:
def benchmark_fn():
for i in range(config.num_layers):
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
# Benchmark # Benchmark
times = [] times = []
for _ in range(config.repeats): for _ in range(config.repeats):
...@@ -399,16 +447,7 @@ def _run_single_benchmark( ...@@ -399,16 +447,7 @@ def _run_single_benchmark(
end = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
start.record() start.record()
for i in range(config.num_layers): benchmark_fn()
impl.forward(
layer,
q_list[i],
k_list[i],
v_list[i],
cache_list[i],
attn_metadata,
output=out,
)
end.record() end.record()
torch.accelerator.synchronize() torch.accelerator.synchronize()
...@@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: ...@@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
common_attn_metadata=common_metadata, common_attn_metadata=common_metadata,
) )
# Only quantize queries when the impl supports it
quantize_query = config.kv_cache_dtype.startswith("fp8") and getattr(
impl, "supports_quant_query_input", False
)
q_list, k_list, v_list = _create_input_tensors( q_list, k_list, v_list = _create_input_tensors(
config, total_q, device, dtype config, total_q, device, dtype, quantize_query=quantize_query
) )
cache_list = _create_kv_cache( cache_list = _create_kv_cache(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment