Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3a51d20
Unverified
Commit
a3a51d20
authored
Mar 16, 2026
by
Wei Zhao
Committed by
GitHub
Mar 16, 2026
Browse files
[Benchmark] Improvements to attention benchmark script (#37115)
Signed-off-by:
wzhao18
<
wzhao18.sz@gmail.com
>
parent
e5b80760
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
311 additions
and
68 deletions
+311
-68
benchmarks/attention_benchmarks/benchmark.py
benchmarks/attention_benchmarks/benchmark.py
+54
-16
benchmarks/attention_benchmarks/common.py
benchmarks/attention_benchmarks/common.py
+5
-0
benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml
benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml
+3
-3
benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml
...marks/attention_benchmarks/configs/mla_sparse_decode.yaml
+58
-0
benchmarks/attention_benchmarks/mla_runner.py
benchmarks/attention_benchmarks/mla_runner.py
+132
-33
benchmarks/attention_benchmarks/runner.py
benchmarks/attention_benchmarks/runner.py
+59
-16
No files found.
benchmarks/attention_benchmarks/benchmark.py
View file @
a3a51d20
...
...
@@ -47,6 +47,8 @@ from common import (
is_mla_backend
,
)
from
vllm.v1.worker.workspace
import
init_workspace_manager
def
run_standard_attention_benchmark
(
config
:
BenchmarkConfig
)
->
BenchmarkResult
:
"""Run standard attention benchmark (Flash/Triton/FlashInfer)."""
...
...
@@ -462,7 +464,7 @@ def main():
parser
.
add_argument
(
"--batch-specs"
,
nargs
=
"+"
,
default
=
[
"q2k"
,
"8q1s1k"
]
,
default
=
None
,
help
=
"Batch specifications using extended grammar"
,
)
...
...
@@ -478,6 +480,21 @@ def main():
parser
.
add_argument
(
"--repeats"
,
type
=
int
,
default
=
1
,
help
=
"Repetitions"
)
parser
.
add_argument
(
"--warmup-iters"
,
type
=
int
,
default
=
3
,
help
=
"Warmup iterations"
)
parser
.
add_argument
(
"--profile-memory"
,
action
=
"store_true"
,
help
=
"Profile memory"
)
parser
.
add_argument
(
"--kv-cache-dtype"
,
default
=
"auto"
,
choices
=
[
"auto"
,
"fp8"
],
help
=
"KV cache dtype: auto or fp8"
,
)
parser
.
add_argument
(
"--cuda-graphs"
,
action
=
argparse
.
BooleanOptionalAction
,
default
=
True
,
help
=
(
"Launch kernels with CUDA graphs to eliminate CPU overhead "
"in measurements (default: True)"
),
)
# Parameter sweep (use YAML config for advanced sweeps)
parser
.
add_argument
(
...
...
@@ -536,21 +553,24 @@ def main():
# Batch specs and sizes
# Support both explicit batch_specs and generated batch_spec_ranges
if
"batch_spec_ranges"
in
yaml_config
:
# Generate batch specs from ranges
generated_specs
=
generate_batch_specs_from_ranges
(
yaml_config
[
"batch_spec_ranges"
]
)
# Combine with any explicit batch_specs
if
"batch_specs"
in
yaml_config
:
args
.
batch_specs
=
yaml_config
[
"batch_specs"
]
+
generated_specs
else
:
args
.
batch_specs
=
generated_specs
console
.
print
(
f
"[dim]Generated
{
len
(
generated_specs
)
}
batch specs from ranges[/]"
)
elif
"batch_specs"
in
yaml_config
:
args
.
batch_specs
=
yaml_config
[
"batch_specs"
]
# CLI --batch-specs takes precedence over YAML when provided.
cli_batch_specs_provided
=
args
.
batch_specs
is
not
None
if
not
cli_batch_specs_provided
:
if
"batch_spec_ranges"
in
yaml_config
:
# Generate batch specs from ranges
generated_specs
=
generate_batch_specs_from_ranges
(
yaml_config
[
"batch_spec_ranges"
]
)
# Combine with any explicit batch_specs
if
"batch_specs"
in
yaml_config
:
args
.
batch_specs
=
yaml_config
[
"batch_specs"
]
+
generated_specs
else
:
args
.
batch_specs
=
generated_specs
console
.
print
(
f
"[dim]Generated
{
len
(
generated_specs
)
}
batch specs from ranges[/]"
)
elif
"batch_specs"
in
yaml_config
:
args
.
batch_specs
=
yaml_config
[
"batch_specs"
]
if
"batch_sizes"
in
yaml_config
:
args
.
batch_sizes
=
yaml_config
[
"batch_sizes"
]
...
...
@@ -575,6 +595,10 @@ def main():
args
.
warmup_iters
=
yaml_config
[
"warmup_iters"
]
if
"profile_memory"
in
yaml_config
:
args
.
profile_memory
=
yaml_config
[
"profile_memory"
]
if
"kv_cache_dtype"
in
yaml_config
:
args
.
kv_cache_dtype
=
yaml_config
[
"kv_cache_dtype"
]
if
"cuda_graphs"
in
yaml_config
:
args
.
cuda_graphs
=
yaml_config
[
"cuda_graphs"
]
# Parameter sweep configuration
if
"parameter_sweep"
in
yaml_config
:
...
...
@@ -629,12 +653,18 @@ def main():
# Determine backends
backends
=
args
.
backends
or
([
args
.
backend
]
if
args
.
backend
else
[
"flash"
])
prefill_backends
=
getattr
(
args
,
"prefill_backends"
,
None
)
if
not
args
.
batch_specs
:
args
.
batch_specs
=
[
"q2k"
,
"8q1s1k"
]
console
.
print
(
f
"Backends:
{
', '
.
join
(
backends
)
}
"
)
if
prefill_backends
:
console
.
print
(
f
"Prefill backends:
{
', '
.
join
(
prefill_backends
)
}
"
)
console
.
print
(
f
"Batch specs:
{
', '
.
join
(
args
.
batch_specs
)
}
"
)
console
.
print
(
f
"KV cache dtype:
{
args
.
kv_cache_dtype
}
"
)
console
.
print
(
f
"CUDA graphs:
{
args
.
cuda_graphs
}
"
)
console
.
print
()
init_workspace_manager
(
args
.
device
)
# Run benchmarks
all_results
=
[]
...
...
@@ -687,6 +717,8 @@ def main():
repeats
=
args
.
repeats
,
warmup_iters
=
args
.
warmup_iters
,
profile_memory
=
args
.
profile_memory
,
kv_cache_dtype
=
args
.
kv_cache_dtype
,
use_cuda_graphs
=
args
.
cuda_graphs
,
)
# Add decode pipeline config
...
...
@@ -839,6 +871,8 @@ def main():
"repeats"
:
args
.
repeats
,
"warmup_iters"
:
args
.
warmup_iters
,
"profile_memory"
:
args
.
profile_memory
,
"kv_cache_dtype"
:
args
.
kv_cache_dtype
,
"use_cuda_graphs"
:
args
.
cuda_graphs
,
}
all_results
=
run_model_parameter_sweep
(
backends
,
...
...
@@ -861,6 +895,8 @@ def main():
"repeats"
:
args
.
repeats
,
"warmup_iters"
:
args
.
warmup_iters
,
"profile_memory"
:
args
.
profile_memory
,
"kv_cache_dtype"
:
args
.
kv_cache_dtype
,
"use_cuda_graphs"
:
args
.
cuda_graphs
,
}
all_results
=
run_parameter_sweep
(
backends
,
args
.
batch_specs
,
base_config_args
,
args
.
parameter_sweep
,
console
...
...
@@ -891,6 +927,8 @@ def main():
repeats
=
args
.
repeats
,
warmup_iters
=
args
.
warmup_iters
,
profile_memory
=
args
.
profile_memory
,
kv_cache_dtype
=
args
.
kv_cache_dtype
,
use_cuda_graphs
=
args
.
cuda_graphs
,
)
result
=
run_benchmark
(
config
)
...
...
benchmarks/attention_benchmarks/common.py
View file @
a3a51d20
...
...
@@ -213,6 +213,9 @@ class BenchmarkConfig:
profile_memory
:
bool
=
False
use_cuda_graphs
:
bool
=
False
# "auto" or "fp8"
kv_cache_dtype
:
str
=
"auto"
# MLA-specific
prefill_backend
:
str
|
None
=
None
kv_lora_rank
:
int
|
None
=
None
...
...
@@ -369,6 +372,7 @@ class ResultsFormatter:
"backend"
,
"batch_spec"
,
"num_layers"
,
"kv_cache_dtype"
,
"mean_time"
,
"std_time"
,
"throughput"
,
...
...
@@ -382,6 +386,7 @@ class ResultsFormatter:
"backend"
:
r
.
config
.
backend
,
"batch_spec"
:
r
.
config
.
batch_spec
,
"num_layers"
:
r
.
config
.
num_layers
,
"kv_cache_dtype"
:
r
.
config
.
kv_cache_dtype
,
"mean_time"
:
r
.
mean_time
,
"std_time"
:
r
.
std_time
,
"throughput"
:
r
.
throughput_tokens_per_sec
or
0
,
...
...
benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml
View file @
a3a51d20
...
...
@@ -30,9 +30,9 @@ batch_specs:
-
"
2q16k_32q1s4k"
# 2 very large prefill + 32 decode
# Context extension + decode
-
"
2q1k
kv
2k_16q1s1k"
# 2 extend + 16 decode
-
"
4q2k
kv
4k_32q1s2k"
# 4 extend + 32 decode
-
"
2q1k
kv
8k_32q1s2k"
# 2 large extend + 32 decode
-
"
2q1k
s
2k_16q1s1k"
# 2 extend + 16 decode
-
"
4q2k
s
4k_32q1s2k"
# 4 extend + 32 decode
-
"
2q1k
s
8k_32q1s2k"
# 2 large extend + 32 decode
# Explicitly chunked prefill
-
"
q8k"
# 8k prefill with chunking hint
...
...
benchmarks/attention_benchmarks/configs/mla_sparse_decode.yaml
0 → 100644
View file @
a3a51d20
# MLA sparse decode-only benchmark configuration
model
:
name
:
"
deepseek-v3"
num_layers
:
60
num_q_heads
:
128
# Base value, can be swept for TP simulation
num_kv_heads
:
1
# MLA uses single latent KV
head_dim
:
576
kv_lora_rank
:
512
qk_nope_head_dim
:
128
qk_rope_head_dim
:
64
v_head_dim
:
128
block_size
:
128
# CUTLASS MLA and FlashAttn MLA use 128
# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep
:
param_name
:
"
num_q_heads"
values
:
[
128
,
64
,
32
,
16
]
label_format
:
"
{backend}_{value}h"
batch_specs
:
# Small batches, varying sequence lengths
-
"
16q1s512"
# 16 requests, 512 KV cache
-
"
16q1s1k"
# 16 requests, 1k KV cache
-
"
16q1s2k"
# 16 requests, 2k KV cache
-
"
16q1s4k"
# 16 requests, 4k KV cache
# Medium batches
-
"
32q1s1k"
# 32 requests, 1k KV cache
-
"
32q1s2k"
# 32 requests, 2k KV cache
-
"
32q1s4k"
# 32 requests, 4k KV cache
-
"
32q1s8k"
# 32 requests, 8k KV cache
# Large batches
-
"
64q1s1k"
# 64 requests, 1k KV cache
-
"
64q1s2k"
# 64 requests, 2k KV cache
-
"
64q1s4k"
# 64 requests, 4k KV cache
-
"
64q1s8k"
# 64 requests, 8k KV cache
# Very large batches
-
"
128q1s1k"
# 128 requests, 1k KV cache
-
"
128q1s2k"
# 128 requests, 2k KV cache
-
"
128q1s4k"
# 128 requests, 4k KV cache
-
"
128q1s8k"
# 128 requests, 8k KV cache
# Long context
-
"
32q1s16k"
# 32 requests, 16k KV cache
-
"
32q1s32k"
# 32 requests, 32k KV cache
backends
:
-
FLASHMLA_SPARSE
-
FLASHINFER_MLA_SPARSE
device
:
"
cuda:0"
repeats
:
100
warmup_iters
:
10
profile_memory
:
true
benchmarks/attention_benchmarks/mla_runner.py
View file @
a3a51d20
...
...
@@ -60,9 +60,11 @@ def create_minimal_vllm_config(
model_name
:
str
=
"deepseek-v3"
,
block_size
:
int
=
128
,
max_num_seqs
:
int
=
256
,
max_num_batched_tokens
:
int
=
8192
,
mla_dims
:
dict
|
None
=
None
,
index_topk
:
int
|
None
=
None
,
prefill_backend
:
str
|
None
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
)
->
VllmConfig
:
"""
Create minimal VllmConfig for MLA benchmarks.
...
...
@@ -149,13 +151,13 @@ def create_minimal_vllm_config(
cache_config
=
CacheConfig
(
block_size
=
block_size
,
gpu_memory_utilization
=
0.9
,
cache_dtype
=
"auto"
,
cache_dtype
=
kv_cache_dtype
,
enable_prefix_caching
=
False
,
)
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
8192
,
max_num_batched_tokens
=
max
(
max_num_batched_tokens
,
max_num_seqs
)
,
max_model_len
=
32768
,
is_encoder_decoder
=
False
,
enable_chunked_prefill
=
True
,
...
...
@@ -535,6 +537,7 @@ def _create_backend_impl(
device
:
torch
.
device
,
max_num_tokens
:
int
=
8192
,
index_topk
:
int
|
None
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
):
"""
Create backend implementation instance.
...
...
@@ -583,7 +586,7 @@ def _create_backend_impl(
"num_kv_heads"
:
mla_dims
[
"num_kv_heads"
],
"alibi_slopes"
:
None
,
"sliding_window"
:
None
,
"kv_cache_dtype"
:
"auto"
,
"kv_cache_dtype"
:
kv_cache_dtype
,
"logits_soft_cap"
:
None
,
"attn_type"
:
"decoder"
,
"kv_sharing_target_layer_name"
:
None
,
...
...
@@ -701,6 +704,7 @@ def _run_single_benchmark(
mla_dims
:
dict
,
device
:
torch
.
device
,
indexer
=
None
,
kv_cache_dtype
:
str
|
None
=
None
,
)
->
BenchmarkResult
:
"""
Run a single benchmark iteration.
...
...
@@ -734,49 +738,124 @@ def _run_single_benchmark(
)
# Create KV cache
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
mla_dims
[
"kv_lora_rank"
]
+
mla_dims
[
"qk_rope_head_dim"
],
device
=
device
,
dtype
=
torch
.
bfloat16
,
)
if
kv_cache_dtype
is
None
:
kv_cache_dtype
=
getattr
(
config
,
"kv_cache_dtype"
,
"auto"
)
head_size
=
mla_dims
[
"kv_lora_rank"
]
+
mla_dims
[
"qk_rope_head_dim"
]
if
kv_cache_dtype
==
"fp8_ds_mla"
:
# FlashMLA sparse custom format: 656 bytes per token, stored as uint8.
# Layout: kv_lora_rank fp8 bytes + 4 float32 tile scales
# + 2*rope_dim bf16 bytes
# = 512 + 16 + 128 = 656 bytes for DeepSeek dims.
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
656
,
device
=
device
,
dtype
=
torch
.
uint8
,
)
elif
kv_cache_dtype
==
"fp8"
:
from
vllm.platforms
import
current_platform
# Create input tensors for both decode and prefill modes
decode_inputs
,
prefill_inputs
=
_create_input_tensors
(
total_q
,
mla_dims
,
backend_cfg
[
"query_format"
],
device
,
torch
.
bfloat16
,
)
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
head_size
,
device
=
device
,
dtype
=
torch
.
uint8
,
).
view
(
current_platform
.
fp8_dtype
())
else
:
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
head_size
,
device
=
device
,
dtype
=
torch
.
bfloat16
,
)
# Fill indexer with random indices for sparse backends
is_sparse
=
backend_cfg
.
get
(
"is_sparse"
,
False
)
if
is_sparse
and
indexer
is
not
None
:
indexer
.
fill_random_indices
(
total_q
,
max_kv_len
)
# Determine which forward method to use based on metadata
if
metadata
.
decode
is
not
None
:
forward_fn
=
lambda
:
impl
.
forward_mqa
(
decode_inputs
,
kv_cache
,
metadata
,
layer
)
elif
metadata
.
prefill
is
not
None
:
forward_fn
=
lambda
:
impl
.
forward_mha
(
prefill_inputs
[
"q"
],
prefill_inputs
[
"k_c_normed"
],
prefill_inputs
[
"k_pe"
],
kv_cache
,
metadata
,
prefill_inputs
[
"k_scale"
],
prefill_inputs
[
"output"
],
)
else
:
# Determine which forward methods to use based on metadata.
# Sparse MLA backends always use forward_mqa
has_decode
=
is_sparse
or
getattr
(
metadata
,
"decode"
,
None
)
is
not
None
has_prefill
=
not
is_sparse
and
getattr
(
metadata
,
"prefill"
,
None
)
is
not
None
if
not
has_decode
and
not
has_prefill
:
raise
RuntimeError
(
"Metadata has neither decode nor prefill metadata"
)
num_decode
=
(
metadata
.
num_decode_tokens
if
(
has_decode
and
has_prefill
)
else
total_q
if
has_decode
else
0
)
num_prefill
=
total_q
-
num_decode
# Some backends require fp8 queries when using an fp8 KV cache.
is_fp8_kvcache
=
kv_cache_dtype
.
startswith
(
"fp8"
)
quantize_query
=
is_fp8_kvcache
and
getattr
(
impl
,
"supports_quant_query_input"
,
False
)
# quantize_query forces concat format
query_fmt
=
"concat"
if
quantize_query
else
backend_cfg
[
"query_format"
]
# Create decode query tensors
if
has_decode
:
decode_inputs
,
_
=
_create_input_tensors
(
num_decode
,
mla_dims
,
query_fmt
,
device
,
torch
.
bfloat16
)
# Cast decode query to fp8 if the backend supports it
if
quantize_query
:
from
vllm.platforms
import
current_platform
if
isinstance
(
decode_inputs
,
tuple
):
decode_inputs
=
torch
.
cat
(
list
(
decode_inputs
),
dim
=-
1
)
decode_inputs
=
decode_inputs
.
to
(
current_platform
.
fp8_dtype
())
# Create prefill input tensors
if
has_prefill
:
_
,
prefill_inputs
=
_create_input_tensors
(
num_prefill
,
mla_dims
,
query_fmt
,
device
,
torch
.
bfloat16
)
# Build forward function
def
forward_fn
():
results
=
[]
if
has_decode
:
results
.
append
(
impl
.
forward_mqa
(
decode_inputs
,
kv_cache
,
metadata
,
layer
))
if
has_prefill
:
results
.
append
(
impl
.
forward_mha
(
prefill_inputs
[
"q"
],
prefill_inputs
[
"k_c_normed"
],
prefill_inputs
[
"k_pe"
],
kv_cache
,
metadata
,
prefill_inputs
[
"k_scale"
],
prefill_inputs
[
"output"
],
)
)
return
results
[
0
]
if
len
(
results
)
==
1
else
tuple
(
results
)
# Warmup
for
_
in
range
(
config
.
warmup_iters
):
forward_fn
()
torch
.
accelerator
.
synchronize
()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if
config
.
use_cuda_graphs
:
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
forward_fn
()
benchmark_fn
=
graph
.
replay
else
:
benchmark_fn
=
forward_fn
# Benchmark
times
=
[]
for
_
in
range
(
config
.
repeats
):
...
...
@@ -785,7 +864,7 @@ def _run_single_benchmark(
start
.
record
()
for
_
in
range
(
config
.
num_layers
):
forward_fn
()
benchmark_fn
()
end
.
record
()
torch
.
accelerator
.
synchronize
()
...
...
@@ -852,13 +931,30 @@ def _run_mla_benchmark_batched(
# Determine if this is a sparse backend
is_sparse
=
backend_cfg
.
get
(
"is_sparse"
,
False
)
# Extract kv_cache_dtype from the first config
kv_cache_dtype
=
getattr
(
first_config
,
"kv_cache_dtype"
,
"auto"
)
# FlashMLA sparse only supports "fp8_ds_mla" internally (not generic "fp8").
# Remap here so the user can pass --kv-cache-dtype fp8 regardless of backend.
if
backend
.
upper
()
==
"FLASHMLA_SPARSE"
and
kv_cache_dtype
==
"fp8"
:
kv_cache_dtype
=
"fp8_ds_mla"
# Compute max total_q across all configs so the metadata builder buffer
# and scheduler config are large enough for all batch specs.
max_total_q
=
max
(
sum
(
r
.
q_len
for
r
in
parse_batch_spec
(
cfg
.
batch_spec
))
for
cfg
,
*
_
in
configs_with_params
)
# Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config
=
create_minimal_vllm_config
(
model_name
=
"deepseek-v3"
,
# Used only for model path
block_size
=
block_size
,
max_num_batched_tokens
=
max_total_q
,
mla_dims
=
mla_dims
,
# Use custom dims from config or default
index_topk
=
index_topk
if
is_sparse
else
None
,
prefill_backend
=
prefill_backend
,
kv_cache_dtype
=
kv_cache_dtype
,
)
results
=
[]
...
...
@@ -883,7 +979,9 @@ def _run_mla_benchmark_batched(
mla_dims
,
vllm_config
,
device
,
max_num_tokens
=
max_total_q
,
index_topk
=
index_topk
if
is_sparse
else
None
,
kv_cache_dtype
=
kv_cache_dtype
,
)
# Verify the actual prefill backend matches what was requested
...
...
@@ -942,6 +1040,7 @@ def _run_mla_benchmark_batched(
mla_dims
,
device
,
indexer
=
indexer
,
kv_cache_dtype
=
kv_cache_dtype
,
)
results
.
append
(
result
)
...
...
benchmarks/attention_benchmarks/runner.py
View file @
a3a51d20
...
...
@@ -140,7 +140,7 @@ def _create_vllm_config(
cache_config
=
CacheConfig
(
block_size
=
config
.
block_size
,
cache_dtype
=
"auto"
,
cache_dtype
=
config
.
kv_cache_dtype
,
)
cache_config
.
num_gpu_blocks
=
max_num_blocks
cache_config
.
num_cpu_blocks
=
0
...
...
@@ -215,7 +215,7 @@ def _create_backend_impl(
num_kv_heads
=
config
.
num_kv_heads
,
alibi_slopes
=
None
,
sliding_window
=
None
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
config
.
kv_cache_dtype
,
)
kv_cache_spec
=
FullAttentionSpec
(
...
...
@@ -288,12 +288,22 @@ def _create_input_tensors(
total_q
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
quantize_query
:
bool
=
False
,
)
->
tuple
:
"""Create Q, K, V input tensors for all layers."""
"""Create Q, K, V input tensors for all layers.
When quantize_query is True, queries are cast to fp8 to match backends
that require query/key/value dtype consistency.
"""
q_dtype
=
dtype
if
quantize_query
:
from
vllm.platforms
import
current_platform
q_dtype
=
current_platform
.
fp8_dtype
()
q_list
=
[
torch
.
randn
(
total_q
,
config
.
num_q_heads
,
config
.
head_dim
,
device
=
device
,
dtype
=
dtype
)
)
.
to
(
q_dtype
)
for
_
in
range
(
config
.
num_layers
)
]
k_list
=
[
...
...
@@ -344,10 +354,17 @@ def _create_kv_cache(
# Compute inverse permutation to get back to logical view
inv_order
=
[
stride_order
.
index
(
i
)
for
i
in
range
(
len
(
stride_order
))]
# Use fp8 dtype for cache when requested.
cache_dtype
=
dtype
if
config
.
kv_cache_dtype
==
"fp8"
:
from
vllm.platforms
import
current_platform
cache_dtype
=
current_platform
.
fp8_dtype
()
cache_list
=
[]
for
_
in
range
(
config
.
num_layers
):
# Allocate in physical layout order (contiguous in memory)
cache
=
torch
.
zeros
(
*
physical_shape
,
device
=
device
,
dtype
=
dtype
)
cache
=
torch
.
zeros
(
*
physical_shape
,
device
=
device
,
dtype
=
cache_dtype
)
# Permute to logical view
cache
=
cache
.
permute
(
*
inv_order
)
cache_list
.
append
(
cache
)
...
...
@@ -392,6 +409,37 @@ def _run_single_benchmark(
)
torch
.
accelerator
.
synchronize
()
# Optionally capture a CUDA graph after warmup.
# Graph replay eliminates CPU launch overhead so timings reflect pure
# kernel time.
if
config
.
use_cuda_graphs
:
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
for
i
in
range
(
config
.
num_layers
):
impl
.
forward
(
layer
,
q_list
[
i
],
k_list
[
i
],
v_list
[
i
],
cache_list
[
i
],
attn_metadata
,
output
=
out
,
)
benchmark_fn
=
graph
.
replay
else
:
def
benchmark_fn
():
for
i
in
range
(
config
.
num_layers
):
impl
.
forward
(
layer
,
q_list
[
i
],
k_list
[
i
],
v_list
[
i
],
cache_list
[
i
],
attn_metadata
,
output
=
out
,
)
# Benchmark
times
=
[]
for
_
in
range
(
config
.
repeats
):
...
...
@@ -399,16 +447,7 @@ def _run_single_benchmark(
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
.
record
()
for
i
in
range
(
config
.
num_layers
):
impl
.
forward
(
layer
,
q_list
[
i
],
k_list
[
i
],
v_list
[
i
],
cache_list
[
i
],
attn_metadata
,
output
=
out
,
)
benchmark_fn
()
end
.
record
()
torch
.
accelerator
.
synchronize
()
...
...
@@ -502,8 +541,12 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
common_attn_metadata
=
common_metadata
,
)
# Only quantize queries when the impl supports it
quantize_query
=
config
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
getattr
(
impl
,
"supports_quant_query_input"
,
False
)
q_list
,
k_list
,
v_list
=
_create_input_tensors
(
config
,
total_q
,
device
,
dtype
config
,
total_q
,
device
,
dtype
,
quantize_query
=
quantize_query
)
cache_list
=
_create_kv_cache
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment