Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3fb4b5fa
Commit
3fb4b5fa
authored
Mar 23, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.18.0' into v0.18.0-ori
parents
bcf25339
89138b21
Changes
488
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1099 additions
and
436 deletions
+1099
-436
benchmarks/auto_tune/auto_tune.sh
benchmarks/auto_tune/auto_tune.sh
+22
-23
benchmarks/auto_tune/batch_auto_tune.sh
benchmarks/auto_tune/batch_auto_tune.sh
+1
-1
benchmarks/backend_request_func.py
benchmarks/backend_request_func.py
+0
-6
benchmarks/benchmark_topk_topp.py
benchmarks/benchmark_topk_topp.py
+474
-0
benchmarks/benchmark_utils.py
benchmarks/benchmark_utils.py
+0
-71
benchmarks/cutlass_benchmarks/utils.py
benchmarks/cutlass_benchmarks/utils.py
+0
-13
benchmarks/disagg_benchmarks/rate_limiter.py
benchmarks/disagg_benchmarks/rate_limiter.py
+0
-45
benchmarks/disagg_benchmarks/request_queue.py
benchmarks/disagg_benchmarks/request_queue.py
+0
-39
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
+2
-0
benchmarks/kernels/bench_concat_mla_q.py
benchmarks/kernels/bench_concat_mla_q.py
+98
-0
benchmarks/kernels/bench_cp_gather_fp8.py
benchmarks/kernels/bench_cp_gather_fp8.py
+153
-0
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
+2
-2
benchmarks/kernels/benchmark_activation.py
benchmarks/kernels/benchmark_activation.py
+2
-0
benchmarks/kernels/benchmark_block_fp8_gemm.py
benchmarks/kernels/benchmark_block_fp8_gemm.py
+2
-0
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+24
-17
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
+29
-12
benchmarks/kernels/benchmark_device_communicators.py
benchmarks/kernels/benchmark_device_communicators.py
+94
-31
benchmarks/kernels/benchmark_fp8_gemm.py
benchmarks/kernels/benchmark_fp8_gemm.py
+0
-0
benchmarks/kernels/benchmark_fused_collective.py
benchmarks/kernels/benchmark_fused_collective.py
+161
-153
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+35
-23
No files found.
Too many changes to show.
To preserve performance only
488 of 488+
files are displayed.
Plain diff
Email patch
benchmarks/auto_tune/auto_tune.sh
View file @
3fb4b5fa
...
...
@@ -46,10 +46,10 @@ echo "VLLM_LOGGING_LEVEL=$VLLM_LOGGING_LEVEL"
echo
"RESULT_FILE=
$RESULT
"
echo
"====================== AUTO TUNEPARAMETERS ===================="
rm
-rf
$LOG_FOLDER
rm
-rf
$PROFILE_PATH
mkdir
-p
$LOG_FOLDER
mkdir
-p
$PROFILE_PATH
rm
-rf
"
$LOG_FOLDER
"
rm
-rf
"
$PROFILE_PATH
"
mkdir
-p
"
$LOG_FOLDER
"
mkdir
-p
"
$PROFILE_PATH
"
cd
"
$BASE
/vllm"
...
...
@@ -85,7 +85,6 @@ start_server() {
# Each argument and its value are separate elements.
local
common_args_array
=(
"
$MODEL
"
"--disable-log-requests"
"--port"
"8004"
"--host"
"
$HOSTNAME
"
"--gpu-memory-utilization"
"
$gpu_memory_utilization
"
...
...
@@ -114,7 +113,7 @@ start_server() {
# wait for 10 minutes...
server_started
=
0
for
i
in
{
1..60
}
;
do
for
_
in
{
1..60
}
;
do
# This line checks whether the server is still alive or not,
# since that we should always have permission to send signal to the server process.
kill
-0
$server_pid
2> /dev/null
||
break
...
...
@@ -145,12 +144,12 @@ run_benchmark() {
local
vllm_log
=
"
$LOG_FOLDER
/vllm_log_
${
max_num_seqs
}
_
${
max_num_batched_tokens
}
.txt"
echo
"vllm_log:
$vllm_log
"
echo
rm
-f
$vllm_log
rm
-f
"
$vllm_log
"
pkill
-if
"vllm serve"
||
true
echo
"starting server..."
# Call start_server without a profile_dir to avoid profiling overhead
start_server
$gpu_memory_utilization
$max_num_seqs
$max_num_batched_tokens
$vllm_log
""
start_server
"
$gpu_memory_utilization
"
"
$max_num_seqs
"
"
$max_num_batched_tokens
"
"
$vllm_log
"
""
result
=
$?
if
[[
"
$result
"
-eq
1
]]
;
then
echo
"server failed to start. gpu_memory_utilization:
$gpu_memory_utilization
, max_num_seqs:
$max_num_seqs
, max_num_batched_tokens:
$max_num_batched_tokens
"
...
...
@@ -168,15 +167,15 @@ run_benchmark() {
# --profile flag is removed from this call
vllm bench serve
\
--backend
vllm
\
--model
$MODEL
\
--model
"
$MODEL
"
\
--dataset-name
random
\
--random-input-len
$adjusted_input_len
\
--random-output-len
$OUTPUT_LEN
\
--random-output-len
"
$OUTPUT_LEN
"
\
--ignore-eos
\
--disable-tqdm
\
--request-rate
inf
\
--percentile-metrics
ttft,tpot,itl,e2el
\
--goodput
e2el:
$MAX_LATENCY_ALLOWED_MS
\
--goodput
e2el:
"
$MAX_LATENCY_ALLOWED_MS
"
\
--num-prompts
1000
\
--random-prefix-len
$prefix_len
\
--host
"
$HOSTNAME
"
\
...
...
@@ -195,20 +194,20 @@ run_benchmark() {
request_rate
=
$((${
throughput
%.*
}
+
1
))
while
((
request_rate
>
0
))
;
do
# clear prefix cache
curl
-X
POST http://
${
HOSTNAME
}
:8004/reset_prefix_cache
curl
-X
POST http://
"
${
HOSTNAME
}
"
:8004/reset_prefix_cache
sleep
5
bm_log
=
"
$LOG_FOLDER
/bm_log_
${
max_num_seqs
}
_
${
max_num_batched_tokens
}
_requestrate_
${
request_rate
}
.txt"
vllm bench serve
\
--backend
vllm
\
--model
$MODEL
\
--model
"
$MODEL
"
\
--dataset-name
random
\
--random-input-len
$adjusted_input_len
\
--random-output-len
$OUTPUT_LEN
\
--random-output-len
"
$OUTPUT_LEN
"
\
--ignore-eos
\
--disable-tqdm
\
--request-rate
$request_rate
\
--percentile-metrics
ttft,tpot,itl,e2el
\
--goodput
e2el:
$MAX_LATENCY_ALLOWED_MS
\
--goodput
e2el:
"
$MAX_LATENCY_ALLOWED_MS
"
\
--num-prompts
100
\
--random-prefix-len
$prefix_len
\
--host
"
$HOSTNAME
"
\
...
...
@@ -255,7 +254,7 @@ gpu_memory_utilization=0.98
find_gpu_memory_utilization
=
0
while
((
$(
echo
"
$gpu_memory_utilization
>= 0.9"
| bc
-l
)
))
;
do
# Pass empty string for profile_dir argument
start_server
$gpu_memory_utilization
"
${
num_seqs_list
[-1]
}
"
"
${
num_batched_tokens_list
[-1]
}
"
"
$LOG_FOLDER
/vllm_log_gpu_memory_utilization_
$gpu_memory_utilization
.log"
""
start_server
"
$gpu_memory_utilization
"
"
${
num_seqs_list
[-1]
}
"
"
${
num_batched_tokens_list
[-1]
}
"
"
$LOG_FOLDER
/vllm_log_gpu_memory_utilization_
$gpu_memory_utilization
.log"
""
result
=
$?
if
[[
"
$result
"
-eq
0
]]
;
then
find_gpu_memory_utilization
=
1
...
...
@@ -274,7 +273,7 @@ fi
for
num_seqs
in
"
${
num_seqs_list
[@]
}
"
;
do
for
num_batched_tokens
in
"
${
num_batched_tokens_list
[@]
}
"
;
do
run_benchmark
$num_seqs
$num_batched_tokens
$gpu_memory_utilization
run_benchmark
"
$num_seqs
"
"
$num_batched_tokens
"
"
$gpu_memory_utilization
"
done
done
echo
"finish permutations"
...
...
@@ -285,7 +284,7 @@ echo "finish permutations"
if
((
$(
echo
"
$best_throughput
> 0"
| bc
-l
)
))
;
then
echo
echo
"Benchmark tuning finished. Now running profiling on the best configuration found..."
echo
"Best config: max_num_seqs:
$best_max_num_seqs
, max_num_batched_tokens:
$best_num_batched_tokens
, throughput:
$best_throughput
"
echo
"Best config: max_num_seqs:
$best_max_num_seqs
, max_num_batched_tokens:
$best_num_batched_tokens
, throughput:
$best_throughput
, goodput:
$best_goodput
"
echo
vllm_log
=
"
$LOG_FOLDER
/vllm_log_BEST_PROFILE.txt"
...
...
@@ -293,7 +292,7 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
# Start server with the best params and profiling ENABLED
echo
"Starting server for profiling..."
start_server
$gpu_memory_utilization
$best_max_num_seqs
$best_num_batched_tokens
"
$vllm_log
"
"
$PROFILE_PATH
"
start_server
"
$gpu_memory_utilization
"
"
$best_max_num_seqs
"
"
$best_num_batched_tokens
"
"
$vllm_log
"
"
$PROFILE_PATH
"
# Run benchmark with the best params and the --profile flag
echo
"Running benchmark with profiling..."
...
...
@@ -301,15 +300,15 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
adjusted_input_len
=
$((
INPUT_LEN
-
prefix_len
))
vllm bench serve
\
--backend
vllm
\
--model
$MODEL
\
--model
"
$MODEL
"
\
--dataset-name
random
\
--random-input-len
$adjusted_input_len
\
--random-output-len
$OUTPUT_LEN
\
--random-output-len
"
$OUTPUT_LEN
"
\
--ignore-eos
\
--disable-tqdm
\
--request-rate
$best_request_rate
\
--request-rate
"
$best_request_rate
"
\
--percentile-metrics
ttft,tpot,itl,e2el
\
--goodput
e2el:
$MAX_LATENCY_ALLOWED_MS
\
--goodput
e2el:
"
$MAX_LATENCY_ALLOWED_MS
"
\
--num-prompts
100
\
--random-prefix-len
$prefix_len
\
--host
"
$HOSTNAME
"
\
...
...
benchmarks/auto_tune/batch_auto_tune.sh
View file @
3fb4b5fa
...
...
@@ -64,7 +64,7 @@ for i in $(seq 0 $(($num_runs - 1))); do
else
STATUS
=
"FAILURE"
((
FAILURE_COUNT++
))
FAILED_RUNS+
=(
"Run #
$((
i+1
))
:
$(
echo
$run_object
| jq
-c
.
)
"
)
FAILED_RUNS+
=(
"Run #
$((
i+1
))
:
$(
echo
"
$run_object
"
| jq
-c
.
)
"
)
fi
RUN_OUTPUT
=
$(
<
"
$RUN_OUTPUT_FILE
"
)
...
...
benchmarks/backend_request_func.py
View file @
3fb4b5fa
...
...
@@ -649,9 +649,3 @@ ASYNC_REQUEST_FUNCS = {
"sglang"
:
async_request_openai_completions
,
"llama.cpp"
:
async_request_openai_completions
,
}
OPENAI_COMPATIBLE_BACKENDS
=
[
k
for
k
,
v
in
ASYNC_REQUEST_FUNCS
.
items
()
if
v
in
(
async_request_openai_completions
,
async_request_openai_chat_completions
)
]
benchmarks/benchmark_topk_topp.py
0 → 100644
View file @
3fb4b5fa
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations.
Compares:
- apply_top_k_top_p_triton (Triton binary search)
- apply_top_k_top_p (PyTorch sort-based)
Scenarios:
- top_k only (whole batch, partial batch)
- top_p only (whole batch, partial batch)
- mix of top_k and top_p
"""
import
argparse
import
gc
from
dataclasses
import
dataclass
import
torch
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p_pytorch
from
vllm.v1.sample.ops.topk_topp_triton
import
(
apply_top_k_top_p_triton
,
reset_buffer_cache
,
)
@dataclass
class BenchmarkConfig:
    """Configuration for a benchmark run."""

    # Scenario identifier; the summary table shows the part before "_b".
    name: str
    # Number of rows in the logits tensor.
    batch_size: int
    # Number of columns (tokens) in the logits tensor.
    vocab_size: int
    # k and p can be tensors or None
    k_values: torch.Tensor | None  # [batch_size] or None
    p_values: torch.Tensor | None  # [batch_size] or None
    # Human-readable line printed while the scenario is running.
    description: str
    ops_pct: float = 0.0  # Percentage of ops relative to batch size
def
calculate_ops_pct
(
k_values
:
torch
.
Tensor
|
None
,
p_values
:
torch
.
Tensor
|
None
,
vocab_size
:
int
,
batch_size
:
int
,
)
->
float
:
"""
Calculate the percentage of active top-k and top-p operations.
Returns percentage where 100% = batch_size ops.
E.g., if all rows have both top-k and top-p active, returns 200%.
"""
active_ops
=
0
if
k_values
is
not
None
:
# Count rows where k < vocab_size (active top-k filtering)
active_ops
+=
(
k_values
<
vocab_size
).
sum
().
item
()
if
p_values
is
not
None
:
# Count rows where p < 1.0 (active top-p filtering)
active_ops
+=
(
p_values
<
1.0
).
sum
().
item
()
return
(
active_ops
/
batch_size
)
*
100
if
batch_size
>
0
else
0.0
def create_logits(
    batch_size: int, vocab_size: int, device: str = "cuda"
) -> torch.Tensor:
    """Build a ``[batch_size, vocab_size]`` tensor of peaked, shuffled logits.

    Every row holds the same multiset of values: the log of a normalized
    Zipf-like distribution ``p(rank) ~ rank^-1.1``. Each row is then shuffled
    with its own random permutation, so a few tokens carry most of the
    probability mass but their positions differ from row to row.
    """
    zipf_exponent = -1.1

    # Normalized Zipf-like probabilities over ranks 1..vocab_size.
    rank_index = torch.arange(
        1, vocab_size + 1, dtype=torch.float32, device=device
    )
    weights = rank_index.pow(zipf_exponent)
    weights = weights / weights.sum()

    # Log-probabilities serve as the (unnormalized) logits.
    row_template = weights.log()

    # Materialize one private copy of the template per row, then shuffle it.
    out = row_template.unsqueeze(0).expand(batch_size, -1).clone()
    for row in range(batch_size):
        shuffle = torch.randperm(vocab_size, device=device)
        out[row] = out[row, shuffle]
    return out
def measure_memory() -> tuple[int, int]:
    """Return (currently allocated, peak allocated) device memory in bytes.

    The device is synchronized first. The second element comes from
    ``max_memory_allocated`` (peak allocated), not from a "reserved" counter.
    """
    torch.accelerator.synchronize()
    return (
        torch.accelerator.memory_allocated(),
        torch.accelerator.max_memory_allocated(),
    )
def reset_memory_stats() -> None:
    """Reset peak memory statistics.

    Also clears the Triton top-k/top-p buffer cache via ``reset_buffer_cache``,
    empties the accelerator allocator cache and runs a Python garbage
    collection pass.
    """
    reset_buffer_cache()
    torch.accelerator.reset_peak_memory_stats()
    torch.accelerator.empty_cache()
    # NOTE(review): gc runs after empty_cache, so tensors freed by this
    # collection stay cached until the next call - confirm this is intended.
    gc.collect()
def benchmark_function(
    func,
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    warmup_iters: int = 5,
    benchmark_iters: int = 20,
) -> tuple[float, int]:
    """
    Time ``func(logits_copy, k, p)`` and report ``(avg_time_ms, peak_memory_bytes)``.

    Every call receives a fresh clone of ``logits``. The clone is made before
    the start event is recorded, so it is excluded from the timing. Peak
    memory is read after the memory statistics were reset at the end of the
    warmup phase.
    """
    # Warmup phase: results are discarded.
    for _ in range(warmup_iters):
        scratch = logits.clone()
        func(scratch, k, p)
    torch.accelerator.synchronize()

    # Start the measured phase from clean memory statistics.
    reset_memory_stats()

    # One (start, end) event pair per measured iteration.
    timers = [
        (
            torch.cuda.Event(enable_timing=True),
            torch.cuda.Event(enable_timing=True),
        )
        for _ in range(benchmark_iters)
    ]

    for begin, finish in timers:
        scratch = logits.clone()
        begin.record()
        func(scratch, k, p)
        finish.record()
    torch.accelerator.synchronize()

    # elapsed_time reports milliseconds between the two recorded events.
    durations = [begin.elapsed_time(finish) for begin, finish in timers]
    mean_ms = sum(durations) / len(durations)

    _, peak_bytes = measure_memory()
    return mean_ms, peak_bytes
def create_benchmark_configs(
    batch_sizes: list[int],
    vocab_sizes: list[int],
    device: str = "cuda",
) -> list[BenchmarkConfig]:
    """Create all benchmark configurations.

    For every (vocab_size, batch_size) pair, six scenarios are emitted in this
    order: top-k whole batch, top-k partial batch, top-p whole batch, top-p
    partial batch, top-k + top-p whole batch, and a mixed partial batch.

    Args:
        batch_sizes: Batch sizes to generate scenarios for.
        vocab_sizes: Vocabulary sizes to generate scenarios for (outer loop).
        device: Device on which the k / p tensors are allocated.

    Returns:
        The list of configurations, grouped by vocab size, then batch size.
    """
    configs: list[BenchmarkConfig] = []

    def add(scenario, summary, k_values, p_values, batch_size, vocab_size):
        # Shared construction so naming and ops_pct stay uniform across scenarios.
        configs.append(
            BenchmarkConfig(
                name=f"{scenario}_b{batch_size}_v{vocab_size // 1000}k",
                batch_size=batch_size,
                vocab_size=vocab_size,
                k_values=k_values,
                p_values=p_values,
                description=f"{summary}, batch={batch_size}, vocab={vocab_size}",
                ops_pct=calculate_ops_pct(
                    k_values, p_values, vocab_size, batch_size
                ),
            )
        )

    for vocab_size in vocab_sizes:
        for batch_size in batch_sizes:
            # 1. Top-k only - whole batch (all rows have k < vocab_size)
            k_all = torch.full(
                (batch_size,), 50, dtype=torch.int32, device=device
            )
            add(
                "topk_whole",
                "Top-k only (whole batch, k=50)",
                k_all,
                None,
                batch_size,
                vocab_size,
            )

            # 2. Top-k only - partial batch (half k=50, half k=vocab_size)
            k_partial = torch.full(
                (batch_size,), 50, dtype=torch.int32, device=device
            )
            k_partial[batch_size // 2 :] = vocab_size  # No filtering for second half
            add(
                "topk_partial",
                "Top-k only (partial batch, 50% k=50, 50% k=vocab)",
                k_partial,
                None,
                batch_size,
                vocab_size,
            )

            # 3. Top-p only - whole batch (all rows have p < 1.0)
            p_all = torch.full(
                (batch_size,), 0.9, dtype=torch.float32, device=device
            )
            add(
                "topp_whole",
                "Top-p only (whole batch, p=0.9)",
                None,
                p_all,
                batch_size,
                vocab_size,
            )

            # 4. Top-p only - partial batch (half p=0.9, half p=1.0)
            p_partial = torch.full(
                (batch_size,), 0.9, dtype=torch.float32, device=device
            )
            p_partial[batch_size // 2 :] = 1.0  # No filtering for second half
            add(
                "topp_partial",
                "Top-p only (partial batch, 50% p=0.9, 50% p=1.0)",
                None,
                p_partial,
                batch_size,
                vocab_size,
            )

            # 5. Mix of top-k and top-p (both applied to whole batch)
            k_mix = torch.full(
                (batch_size,), 100, dtype=torch.int32, device=device
            )
            p_mix = torch.full(
                (batch_size,), 0.9, dtype=torch.float32, device=device
            )
            add(
                "topk_topp_whole",
                "Top-k + Top-p (whole batch, k=100, p=0.9)",
                k_mix,
                p_mix,
                batch_size,
                vocab_size,
            )

            # 6. Mix with partial application
            #    (some rows k only, some p only, some both)
            k_mixed = torch.full(
                (batch_size,), vocab_size, dtype=torch.int32, device=device
            )
            p_mixed = torch.full(
                (batch_size,), 1.0, dtype=torch.float32, device=device
            )
            third = batch_size // 3
            # First third: k only
            k_mixed[:third] = 50
            # Second third: p only
            p_mixed[third : 2 * third] = 0.5
            # Last third (plus any remainder rows): both k and p
            k_mixed[2 * third :] = 100
            p_mixed[2 * third :] = 0.9
            add(
                "mixed_partial",
                # The description states the values actually assigned above.
                "Mixed partial (1/3 k=50, 1/3 p=0.5, 1/3 k=100 + p=0.9)",
                k_mixed,
                p_mixed,
                batch_size,
                vocab_size,
            )

    return configs
def format_memory(bytes_val: int) -> str:
    """Render a byte count with the largest binary unit that fits.

    Values of at least 1 KiB are shown with two decimals in KB, MB or GB
    (powers of 1024); anything smaller is shown as a plain byte count.
    """
    # Ordered from largest to smallest so the first match wins.
    units = (("GB", 1024**3), ("MB", 1024**2), ("KB", 1024))
    for suffix, scale in units:
        if bytes_val >= scale:
            return f"{bytes_val / scale:.2f} {suffix}"
    return f"{bytes_val} B"
def run_benchmark(
    configs: list[BenchmarkConfig],
    warmup_iters: int = 5,
    benchmark_iters: int = 20,
    verbose: bool = True,
) -> list[dict]:
    """Run all benchmarks and print results.

    For each config, fresh logits are created and both implementations are
    timed on them with ``benchmark_function``.

    Args:
        configs: Scenarios to run, in order.
        warmup_iters: Untimed iterations per implementation.
        benchmark_iters: Timed iterations per implementation.
        verbose: When true, print per-scenario progress and numbers.

    Returns:
        One dict per config with keys ``config``, ``triton_time_ms``,
        ``pytorch_time_ms``, ``triton_mem``, ``pytorch_mem``, ``speedup``
        and ``mem_ratio``.
    """
    results = []
    print("=" * 100)
    print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based")
    print("=" * 100)
    print()
    for config in configs:
        if verbose:
            print(f"Running: {config.description}")
        # Create fresh logits for this config
        logits = create_logits(config.batch_size, config.vocab_size)
        # Benchmark Triton
        reset_memory_stats()
        triton_time, triton_mem = benchmark_function(
            apply_top_k_top_p_triton,
            logits,
            config.k_values,
            config.p_values,
            warmup_iters,
            benchmark_iters,
        )
        # Benchmark PyTorch
        reset_memory_stats()
        pytorch_time, pytorch_mem = benchmark_function(
            apply_top_k_top_p_pytorch,
            logits,
            config.k_values,
            config.p_values,
            warmup_iters,
            benchmark_iters,
        )
        # Ratios > 1 favor Triton; a zero denominator is reported as infinity.
        speedup = pytorch_time / triton_time if triton_time > 0 else float("inf")
        mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf")
        result = {
            "config": config,
            "triton_time_ms": triton_time,
            "pytorch_time_ms": pytorch_time,
            "triton_mem": triton_mem,
            "pytorch_mem": pytorch_mem,
            "speedup": speedup,
            "mem_ratio": mem_ratio,
        }
        results.append(result)
        if verbose:
            print(f" Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}")
            print(f" PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}")
            print(f" Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x")
            print()
        # Clean up
        del logits
        reset_memory_stats()
    return results
def print_summary_table(results: list[dict]) -> None:
    """Print a summary table of results.

    Args:
        results: Result dicts as returned by ``run_benchmark``; each one is
            printed as a single row.
    """
    print()
    print("=" * 130)
    print("SUMMARY TABLE")
    print("=" * 130)
    print()
    # Header
    header = (
        f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} "
        f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} "
        f"{'Tri Mem':>10} {'Pyt Mem':>10}"
    )
    print(header)
    print("-" * 130)
    # Rows are printed in input order; a rule is drawn whenever the vocab
    # size changes between consecutive rows.
    current_vocab = None
    for result in results:
        config = result["config"]
        # Add separator between vocab sizes
        if current_vocab != config.vocab_size:
            if current_vocab is not None:
                print("-" * 130)
            current_vocab = config.vocab_size
        scenario = config.name.split("_b")[0]  # Extract scenario name
        print(
            f"{scenario:<40} {config.batch_size:>6} {config.vocab_size:>7} "
            f"{config.ops_pct:>5.0f}% "
            f"{result['triton_time_ms']:>12.3f} {result['pytorch_time_ms']:>13.3f} "
            f"{result['speedup']:>7.2f}x "
            f"{format_memory(result['triton_mem']):>10} "
            f"{format_memory(result['pytorch_mem']):>10}"
        )
    print("=" * 130)
def main():
    """Parse CLI options, run every benchmark scenario and print the summary.

    Prints the chosen configuration, then returns early with an error message
    if CUDA is unavailable. Otherwise it builds the scenario list, runs the
    benchmarks and prints the summary table.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations"
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 4, 16, 64, 128, 512, 1024, 2048],
        # Help text lists the real default so --help does not mislead.
        help="Batch sizes to test (default: 1 4 16 64 128 512 1024 2048)",
    )
    parser.add_argument(
        "--vocab-sizes",
        type=int,
        nargs="+",
        default=[32768, 131072],  # 32k, 128k
        help="Vocabulary sizes to test (default: 32768 131072)",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=5,
        help="Number of warmup iterations (default: 5)",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=20,
        help="Number of benchmark iterations (default: 20)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Only print summary table",
    )
    args = parser.parse_args()

    # Print configuration
    print(f"Batch sizes: {args.batch_sizes}")
    print(f"Vocab sizes: {args.vocab_sizes}")
    print(f"Warmup iterations: {args.warmup_iters}")
    print(f"Benchmark iterations: {args.benchmark_iters}")
    print()

    # Check CUDA: timing below relies on torch.cuda events.
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This benchmark requires a GPU.")
        return
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU: {device_name}")
    print()

    # Create configs
    configs = create_benchmark_configs(
        args.batch_sizes,
        args.vocab_sizes,
    )

    # Run benchmarks
    results = run_benchmark(
        configs,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
        verbose=not args.quiet,
    )

    # Print summary
    print_summary_table(results)
if
__name__
==
"__main__"
:
main
()
benchmarks/benchmark_utils.py
View file @
3fb4b5fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
json
import
math
import
os
import
time
from
types
import
TracebackType
from
typing
import
Any
def convert_to_pytorch_benchmark_format(
    args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
    """
    Save the benchmark results in the format used by PyTorch OSS benchmark with
    one metric per record
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database

    Args:
        args: Parsed CLI arguments; must expose ``model``. The namespace is
            copied into each record and is not modified.
        metrics: Mapping of metric name to its list of benchmark values.
        extra_info: Extra metadata attached to every metric; if it contains
            ``tensor_parallel_size`` and ``args`` has no truthy value for it,
            the value is copied into the record's args.

    Returns:
        One record per metric, or an empty list when the
        ``SAVE_TO_PYTORCH_BENCHMARK_FORMAT`` environment variable is unset
        or empty.
    """
    records = []
    if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
        return records

    for name, benchmark_values in metrics.items():
        # vars(args) is the namespace's live __dict__; copy it so that
        # filling in tensor_parallel_size neither mutates the caller's args
        # nor aliases one dict across all records.
        args_snapshot = dict(vars(args))

        # Save tensor_parallel_size parameter if it's part of the metadata
        if (
            not args_snapshot.get("tensor_parallel_size")
            and "tensor_parallel_size" in extra_info
        ):
            args_snapshot["tensor_parallel_size"] = extra_info[
                "tensor_parallel_size"
            ]

        record = {
            "benchmark": {
                "name": "vLLM benchmark",
                "extra_info": {
                    "args": args_snapshot,
                },
            },
            "model": {
                "name": args.model,
            },
            "metric": {
                "name": name,
                "benchmark_values": benchmark_values,
                "extra_info": extra_info,
            },
        }
        records.append(record)

    return records
class InfEncoder(json.JSONEncoder):
    """JSON encoder that writes infinite floats as the string ``"inf"``.

    The substitution happens before encoding and recurses through dicts and
    lists only; other containers are passed through untouched.
    """

    def clear_inf(self, o: Any):
        """Return ``o`` with every infinite float replaced by ``"inf"``.

        Note that both positive and negative infinity map to ``"inf"``.
        """
        if isinstance(o, float):
            return "inf" if math.isinf(o) else o
        if isinstance(o, list):
            return [self.clear_inf(item) for item in o]
        if isinstance(o, dict):
            return {key: self.clear_inf(value) for key, value in o.items()}
        return o

    def iterencode(self, o: Any, *args, **kwargs) -> Any:
        """Encode ``o`` after scrubbing infinities with ``clear_inf``."""
        sanitized = self.clear_inf(o)
        return super().iterencode(sanitized, *args, **kwargs)
def write_to_json(filename: str, records: list) -> None:
    """Write ``records`` to ``filename`` as JSON.

    Infinite floats are handled by ``InfEncoder``. Objects that JSON cannot
    serialize are replaced by a placeholder string naming their type instead
    of raising.
    """

    def describe_unserializable(o):
        # Fallback used by json for objects it has no encoding for.
        return f"<{type(o).__name__} object is not JSON serializable>"

    with open(filename, "w") as f:
        json.dump(
            records,
            f,
            cls=InfEncoder,
            default=describe_unserializable,
        )
# Collect time and generate time metrics
...
...
benchmarks/cutlass_benchmarks/utils.py
View file @
3fb4b5fa
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Cutlass bench utils
from
collections.abc
import
Iterable
import
torch
...
...
@@ -86,15 +85,3 @@ def make_rand_sparse_tensors(
# Compressed B, Metadata, Original A, B
return
b_compressed
,
e
,
a
,
b
def make_n_rand_sparse_tensors(
    num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[
    list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
]:
    """Generate up to ``num_tensors`` random sparse tensor sets.

    Calls ``make_rand_sparse_tensors(dtype, m, n, k)`` once per requested set
    and keeps the result when its compressed B is not ``None``.

    Returns:
        Four parallel lists ``(b_compressed, metadata, a, b)``. All four are
        empty when no set was produced (e.g. ``num_tensors == 0``).
    """
    ABs = []
    for _ in range(num_tensors):
        generated = make_rand_sparse_tensors(dtype, m, n, k)
        # Keep the set that was actually checked; the original generated a
        # second, unchecked set here and discarded this one.
        if generated[0] is not None:
            ABs.append(generated)

    if not ABs:
        # zip(*[]) yields nothing, so the 4-way unpack below would raise.
        return [], [], [], []

    BComps, Es, As, Bs = zip(*ABs)
    return list(BComps), list(Es), list(As), list(Bs)
benchmarks/disagg_benchmarks/rate_limiter.py
deleted
100644 → 0
View file @
bcf25339
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
time
class RateLimiter:
    """Async rate limiter that hands out ``rate_limit`` tokens per window.

    The bucket is refilled to ``rate_limit`` once more than one second has
    passed since the previous refill. Usable directly via ``acquire()`` or as
    an async context manager.
    """

    def __init__(self, rate_limit):
        # Requests allowed per refill window.
        self.rate_limit = rate_limit
        # Tokens still available in the current window.
        self.num_available_tokens = rate_limit
        # Monotonic timestamp of the last refill.
        self.last_refill = time.monotonic()
        # Guards the token count and refill timestamp.
        self.lock = asyncio.Lock()

    async def acquire(self):
        """Wait until a token is available, consume it and return ``True``."""
        while True:
            async with self.lock:
                now = time.monotonic()
                since_refill = now - self.last_refill

                # Top the bucket back up once the window has elapsed.
                if since_refill > 1.0:
                    self.num_available_tokens = self.rate_limit
                    self.last_refill = now

                # Consume a token if one is left.
                if self.num_available_tokens > 0:
                    self.num_available_tokens -= 1
                    return True

                # Bucket is empty: wait out the remainder of the window.
                delay = 1.0 - since_refill

            # Sleep outside the lock so other waiters are not blocked.
            await asyncio.sleep(delay)

    async def __aenter__(self):
        """Acquire a token on entry and return the limiter itself."""
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Nothing to release on exit."""
        pass
benchmarks/disagg_benchmarks/request_queue.py
deleted
100644 → 0
View file @
bcf25339
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
from
collections
import
deque
class RequestQueue:
    """Request queue manager with concurrency control"""

    def __init__(self, max_concurrent, max_queue_size):
        # Maximum concurrent requests
        self.max_concurrent = max_concurrent
        self.max_queue_size = max_queue_size  # Maximum queue size
        # Concurrency control
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.queue = deque()  # Request queue
        self.queue_size = 0  # Current queue size
        self.lock = asyncio.Lock()  # Sync queue Lock

    async def enqueue(self, task):
        """Add a request task to the queue.

        Returns False without queuing when the queue already holds
        ``max_queue_size`` tasks, True otherwise.
        """
        async with self.lock:
            if self.queue_size >= self.max_queue_size:
                return False
            self.queue.append(task)
            self.queue_size += 1
            return True

    async def process(self):
        """Process queued requests using semaphore for concurrency control.

        Runs forever; intended to be started as a background task.
        """
        while True:
            if self.queue:
                async with self.semaphore, self.lock:
                    task = self.queue.popleft()
                    self.queue_size -= 1
                    # NOTE(review): the task appears to be awaited while the
                    # queue lock is held, which would serialize tasks and block
                    # enqueue() meanwhile - confirm against the original file.
                    await task
            await asyncio.sleep(0.01)  # Yield control to event loop
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
View file @
3fb4b5fa
...
...
@@ -13,6 +13,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
from
tqdm
import
tqdm
import
vllm._custom_ops
as
ops
from
vllm.benchmarks.lib.utils
import
default_vllm_config
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
...
...
@@ -291,6 +292,7 @@ def print_timers(timers: Iterable[TMeasurement]):
compare
.
print
()
@
default_vllm_config
()
def
main
():
torch
.
set_default_device
(
"cuda"
)
bench_params
=
get_bench_params
()
...
...
benchmarks/kernels/bench_concat_mla_q.py
0 → 100644
View file @
3fb4b5fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.triton_utils
import
triton
# DeepSeek V3 dimensions
NOPE_DIM
=
512
ROPE_DIM
=
64
NUM_HEADS
=
128
NUM_TOKENS
=
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
def get_configs():
    """Return the token counts swept on the benchmark's x-axis (NUM_TOKENS)."""
    return NUM_TOKENS
def make_inputs(num_tokens, dtype):
    """Create inputs matching the real code path.

    ``ql_nope`` is allocated as ``[NUM_HEADS, num_tokens, NOPE_DIM]`` and then
    returned as its ``transpose(0, 1)`` view, i.e. shaped
    ``[num_tokens, NUM_HEADS, NOPE_DIM]`` with the strides of the transposed
    BMM output. ``q_pe`` is a regular ``[num_tokens, NUM_HEADS, ROPE_DIM]``
    tensor.

    Args:
        num_tokens: Number of tokens (first dimension of both outputs).
        dtype: Element type of both tensors.

    Returns:
        ``(ql_nope, q_pe)``, both on the CUDA device.
    """
    # Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L]
    raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda")
    ql_nope = raw.transpose(0, 1)
    q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda")
    return ql_nope, q_pe
# ---- Non-contiguous nope benchmark (real code path) ----
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=get_configs(),
        line_arg="provider",
        line_vals=["torch_cat", "concat_mla_q"],
        line_names=["torch.cat", "concat_mla_q (v8)"],
        styles=[("blue", "--"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="concat_mla_q-transposed",
        args={},
    )
)
def bench_transposed(num_tokens, provider):
    """Benchmark concatenating ``ql_nope`` and ``q_pe`` along the last dim.

    Args:
        num_tokens: Number of tokens in the inputs.
        provider: ``"torch_cat"`` for ``torch.cat``; any other value runs
            ``ops.concat_mla_q`` into a preallocated output buffer.

    Returns:
        ``(median, p20, p80)`` latency in microseconds.
    """
    dtype = torch.bfloat16
    ql_nope, q_pe = make_inputs(num_tokens, dtype)
    q_out = torch.empty(
        num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda"
    )
    # Median first, then the 20th / 80th percentiles as the low / high band.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch_cat":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500
        )
    else:
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.concat_mla_q(ql_nope, q_pe, q_out),
            quantiles=quantiles,
            rep=500,
        )
    # perf_report unpacks the result as (value, min, max). Latency is not
    # inverted like a throughput metric, so min must precede max here.
    return ms * 1000, min_ms * 1000, max_ms * 1000  # us
if __name__ == "__main__":
    # Script entry point: parse options, print a banner describing the
    # benchmarked dimensions, then run the perf_report benchmark.
    parser = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat")
    parser.add_argument(
        "--save-path", type=str, default=None, help="Path to save benchmark results"
    )
    args = parser.parse_args()
    print("\n" + "=" * 70)
    print("CONCAT MLA Q KERNEL BENCHMARKS")
    print("=" * 70)
    print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}")
    # bf16 elements are 2 bytes each, hence the "* 2".
    print(
        f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = "
        f"{(NOPE_DIM + ROPE_DIM) * 2} bytes"
    )
    print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}")
    print("=" * 70)
    print("\n--- Non-contiguous nope inputs (transposed BMM output) ---")
    # save_path is forwarded unchanged; None is the argparse default.
    bench_transposed.run(print_data=True, save_path=args.save_path)
    print("\n" + "=" * 70)
    print("Benchmarking complete!")
    print("=" * 70)
benchmarks/kernels/bench_cp_gather_fp8.py
0 → 100644
View file @
3fb4b5fa
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
math
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.triton_utils
import
triton
# DeepSeek V3 MLA dimensions
NOPE_DIM
=
512
ROPE_DIM
=
64
HEAD_DIM
=
NOPE_DIM
+
ROPE_DIM
# 576 BF16 output elements per token
ENTRY_BYTES
=
656
# 512 FP8 + 16 scales + 128 BF16 RoPE
BLOCK_SIZE
=
64
# tokens per physical cache block - get_supported_kernel_block_sizes
# Prefill scenarios. Every scenario sweeps the same total token counts
# (8K-96K); the totals are split evenly across the requests:
# - 1 long prefill: single request, 8K-96K tokens
# - 4 medium prefills: 4 requests, 2K-24K tokens each
# - 16 shorter prefills: 16 requests, 512-6K tokens each
SCENARIOS = [
    # (label, num_reqs, total_tokens_list)
    ("1-req", 1, [8192, 16384, 32768, 65536, 98304]),
    ("4-reqs", 4, [8192, 16384, 32768, 65536, 98304]),
    ("16-reqs", 16, [8192, 16384, 32768, 65536, 98304]),
]
def make_inputs(total_tokens, num_reqs, block_size):
    """Create synthetic FP8 cache, block table, and output buffer.

    Fills the cache with random bytes (we only measure throughput,
    not correctness). Block table maps each request to contiguous
    physical blocks.

    Args:
        total_tokens: Total number of tokens across all requests.
        num_reqs: Number of requests; tokens are split as evenly as possible,
            with the first ``total_tokens % num_reqs`` requests getting one
            extra token.
        block_size: Tokens per physical cache block.

    Returns:
        ``(cache, dst, block_table, seq_lens_t, workspace_starts_t)``, all
        allocated on the CUDA device.
    """
    # Divide tokens evenly across requests
    base_len, remainder = divmod(total_tokens, num_reqs)
    seq_lens = [base_len + (1 if r < remainder else 0) for r in range(num_reqs)]

    # workspace_starts: exclusive cumulative sum of seq_lens
    workspace_starts = []
    offset = 0
    for length in seq_lens:
        workspace_starts.append(offset)
        offset += length

    # Physical blocks needed per request
    blocks_per_req = [math.ceil(s / block_size) for s in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Allocate cache with random data (content doesn't matter for perf)
    cache = torch.randint(
        0,
        256,
        (total_blocks, block_size, ENTRY_BYTES),
        dtype=torch.uint8,
        device="cuda",
    )

    # Block table: contiguous block assignments. Each row is filled with one
    # slice write instead of one scalar device write per block.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_block = 0
    for r, n_blocks in enumerate(blocks_per_req):
        block_table[r, :n_blocks] = torch.arange(
            next_block, next_block + n_blocks, dtype=torch.int32, device="cuda"
        )
        next_block += n_blocks

    # Output workspace
    dst = torch.zeros(total_tokens, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )
    return cache, dst, block_table, seq_lens_t, workspace_starts_t
def bench_scenario(label, num_reqs, total_tokens_list, save_path):
    """Run benchmark for a specific (num_reqs, total_tokens) scenario.

    Args:
        label: Scenario name, used in the plot name and the printed header.
        num_reqs: Number of requests the total tokens are split across.
        total_tokens_list: Total token counts swept along the x axis; must be
            non-empty (the first and last entries are used for the header).
        save_path: Passed through as ``save_path`` to the Triton report
            runner.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["total_tokens"],
            x_vals=total_tokens_list,
            line_arg="provider",
            line_vals=["cuda_kernel"],
            line_names=["cp_gather_fp8 (CUDA)"],
            styles=[("green", "-")],
            ylabel="Latency (us)",
            plot_name=f"cp_gather_fp8-{label}-bs{BLOCK_SIZE}",
            args={"num_reqs": num_reqs},
        )
    )
    def bench_fn(total_tokens, provider, num_reqs):
        cache, dst, block_table, seq_lens_t, ws_starts = make_inputs(
            total_tokens, num_reqs, BLOCK_SIZE
        )

        # Median first, then the 20th / 80th percentiles as lower / upper bounds.
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.cp_gather_and_upconvert_fp8_kv_cache(
                cache, dst, block_table, seq_lens_t, ws_starts, num_reqs
            ),
            quantiles=quantiles,
            rep=500,
        )

        # perf_report reads the tuple as (value, min, max). Latency is not an
        # inverted metric (unlike throughput), so the lower quantile must come
        # before the upper one. Convert ms -> us.
        return ms * 1000, min_ms * 1000, max_ms * 1000  # us

    seq_len_per_req = total_tokens_list[0] // num_reqs
    seq_len_per_req_max = total_tokens_list[-1] // num_reqs
    print(
        f"\n--- {label}: {num_reqs} request(s), "
        f"~{seq_len_per_req}-{seq_len_per_req_max} tokens/req ---"
    )
    bench_fn.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
    # Command line: only an optional output location for the results.
    parser = argparse.ArgumentParser(
        description="Benchmark cp_gather_and_upconvert_fp8_kv_cache"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results as CSV",
    )
    args = parser.parse_args()

    # Print data volume info for bandwidth analysis
    read_per_token = ENTRY_BYTES  # 656 bytes from cache
    write_per_token = HEAD_DIM * 2  # 576 * 2 = 1152 bytes to workspace
    total_per_token = read_per_token + write_per_token  # 1808 bytes

    rule = "=" * 70

    # Header banner, emitted as one newline-joined print.
    print(
        "\n".join(
            [
                "\n" + rule,
                "CP_GATHER_AND_UPCONVERT_FP8_KV_CACHE BENCHMARKS",
                rule,
                f"Cache entry: {ENTRY_BYTES} bytes (512 FP8 + 16 scales + 128 RoPE)",
                f"Output row: {HEAD_DIM} BF16 = {HEAD_DIM * 2} bytes",
                f"Per token: {total_per_token} bytes (read + write)",
                f"Block size: {BLOCK_SIZE} tokens/block",
                rule,
            ]
        )
    )

    # Each scenario is a (label, num_reqs, total_tokens_list) tuple.
    for scenario in SCENARIOS:
        bench_scenario(*scenario, args.save_path)

    print("\n".join(["\n" + rule, "Benchmarking complete!", rule]))
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
View file @
3fb4b5fa
...
...
@@ -168,7 +168,7 @@ def bench_impl(
# warmup
for
kwargs
in
kwargs_list
:
impl_type
.
get_impl
()(
**
kwargs
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
# Merge into a single kwargs and qualify arguments as ArgPool
kwargs
=
{
k
:
ArgPool
([])
for
k
in
kwargs_list
[
0
]}
...
...
@@ -202,7 +202,7 @@ def test_correctness(T: int, N: int):
# reference output
ref_out_q
,
ref_out_s
=
output_from_impl
(
ImplType
.
REFERENCE
)
# test ou
p
tut
# test out
p
ut
out_q
,
out_s
=
output_from_impl
(
ImplType
.
SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
)
...
...
benchmarks/kernels/benchmark_activation.py
View file @
3fb4b5fa
...
...
@@ -7,6 +7,7 @@ import itertools
import
torch
import
vllm.model_executor.layers.activation
# noqa F401
from
vllm.benchmarks.lib.utils
import
default_vllm_config
from
vllm.model_executor.custom_op
import
op_registry
from
vllm.triton_utils
import
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
...
...
@@ -18,6 +19,7 @@ intermediate_size = [3072, 9728, 12288]
configs
=
list
(
itertools
.
product
(
batch_size_range
,
seq_len_range
,
intermediate_size
))
@
default_vllm_config
()
def
benchmark_activation
(
batch_size
:
int
,
seq_len
:
int
,
...
...
benchmarks/kernels/bench_block_fp8_gemm.py
→
benchmarks/kernels/bench
mark
_block_fp8_gemm.py
View file @
3fb4b5fa
...
...
@@ -8,6 +8,7 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
import
torch
from
vllm.benchmarks.lib.utils
import
default_vllm_config
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
)
...
...
@@ -40,6 +41,7 @@ DEEPSEEK_V3_SHAPES = [
]
@
default_vllm_config
()
def
build_w8a8_block_fp8_runner
(
M
,
N
,
K
,
block_size
,
device
,
use_cutlass
):
"""Build runner function for w8a8 block fp8 matmul."""
factor_for_scale
=
1e-2
...
...
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
View file @
3fb4b5fa
...
...
@@ -11,12 +11,13 @@ import torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_make_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.config
import
fp8_w8a8_moe_quant_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.v1.worker.workspace
import
init_workspace_manager
...
...
@@ -63,7 +64,7 @@ def bench_run(
per_out_ch
:
bool
,
mkn
:
tuple
[
int
,
int
,
int
],
):
init_workspace_manager
(
torch
.
cuda
.
current_device
())
init_workspace_manager
(
torch
.
accelerator
.
current_device
_index
())
(
m
,
k
,
n
)
=
mkn
dtype
=
torch
.
half
...
...
@@ -136,15 +137,21 @@ def bench_run(
per_out_ch_quant
=
per_out_ch
,
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
moe_config
=
make_dummy_moe_config
(
num_experts
=
num_experts
,
hidden_dim
=
k
,
intermediate_size_per_partition
=
n
,
in_dtype
=
a
.
dtype
,
)
fn
=
mk
.
FusedMoEKernel
(
maybe_make_prepare_finalize
(
moe
=
moe_config
,
quant_config
=
quant_config
,
allow_new_interface
=
True
,
use_monolithic
=
False
,
),
CutlassExpertsFp8
(
moe_config
=
make_dummy_moe_config
(
num_experts
=
num_experts
,
hidden_dim
=
k
,
intermediate_size_per_partition
=
n
,
in_dtype
=
a
.
dtype
,
),
moe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
...
...
@@ -161,10 +168,10 @@ def bench_run(
w2_fp8q_cutlass
,
topk_weights
,
topk_ids
,
activation
=
"silu"
,
activation
=
MoEActivation
.
SILU
,
global_num_experts
=
num_experts
,
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
# Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
triton_stream
=
torch
.
cuda
.
Stream
()
...
...
@@ -180,14 +187,14 @@ def bench_run(
topk_ids
,
quant_config
=
quant_config
,
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
def
bench_cuda_graph
(
graph
,
num_warmup
=
5
,
num_iters
=
100
):
"""Benchmark CUDA graph using events like benchmark_moe.py"""
# Warmup
for
_
in
range
(
num_warmup
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
# Timing
start_event
=
torch
.
Event
(
enable_timing
=
True
)
...
...
@@ -195,7 +202,7 @@ def bench_run(
latencies
=
[]
for
_
in
range
(
num_iters
):
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
start_event
.
record
()
graph
.
replay
()
end_event
.
record
()
...
...
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
View file @
3fb4b5fa
...
...
@@ -15,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_make_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
fp8_w8a8_moe_quant_config
,
nvfp4_moe_quant_config
,
...
...
@@ -23,9 +26,6 @@ from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.v1.worker.workspace
import
init_workspace_manager
...
...
@@ -196,10 +196,21 @@ def bench_run(
g2_alphas
=
w2_gs
,
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
moe_config
=
make_dummy_moe_config
(
num_experts
=
num_experts
,
hidden_dim
=
k
,
intermediate_size_per_partition
=
n
,
in_dtype
=
a
.
dtype
,
)
kernel
=
mk
.
FusedMoEKernel
(
maybe_make_prepare_finalize
(
moe
=
moe_config
,
quant_config
=
quant_config
,
allow_new_interface
=
True
,
use_monolithic
=
False
,
),
CutlassExpertsFp4
(
m
ake_dummy_
moe_config
()
,
m
oe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
...
...
@@ -240,11 +251,17 @@ def bench_run(
g1_alphas
=
w1_gs
,
g2_alphas
=
w2_gs
,
)
moe_config
=
make_dummy_moe_config
()
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
kernel
=
mk
.
FusedMoEKernel
(
maybe_make_prepare_finalize
(
moe
=
moe_config
,
quant_config
=
quant_config
,
allow_new_interface
=
True
,
use_monolithic
=
False
,
),
CutlassExpertsFp4
(
m
ake_dummy_
moe_config
()
,
m
oe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
...
...
@@ -290,7 +307,7 @@ def bench_run(
def
replay_graph
(
graph
,
num_repeats
):
for
_
in
range
(
num_repeats
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
cutlass_stream
=
torch
.
cuda
.
Stream
()
cutlass_graph
=
torch
.
cuda
.
CUDAGraph
()
...
...
@@ -313,7 +330,7 @@ def bench_run(
e
=
num_experts
,
device
=
device
,
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
triton_stream
=
torch
.
cuda
.
Stream
()
triton_graph
=
torch
.
cuda
.
CUDAGraph
()
...
...
@@ -328,7 +345,7 @@ def bench_run(
w2_fp8scale
,
a_fp8_scale
,
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
min_run_time
=
5
num_warmup
=
5
...
...
benchmarks/kernels/benchmark_device_communicators.py
View file @
3fb4b5fa
...
...
@@ -30,6 +30,9 @@ import torch.distributed as dist
from
torch.distributed
import
ProcessGroup
from
vllm.distributed.device_communicators.custom_all_reduce
import
CustomAllreduce
from
vllm.distributed.device_communicators.flashinfer_all_reduce
import
(
FlashInferAllReduce
,
)
from
vllm.distributed.device_communicators.pynccl
import
(
PyNcclCommunicator
,
register_nccl_symmetric_ops
,
...
...
@@ -44,7 +47,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
logger
=
init_logger
(
__name__
)
# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS
=
[
128
,
512
,
1024
,
2048
,
4096
,
8192
]
DEFAULT_SEQUENCE_LENGTHS
=
[
16
,
64
,
128
,
512
,
1024
,
2048
,
4096
,
8192
]
# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE
=
8192
...
...
@@ -81,6 +84,7 @@ class CommunicatorBenchmark:
self
.
symm_mem_comm
=
None
self
.
symm_mem_comm_multimem
=
None
self
.
symm_mem_comm_two_shot
=
None
self
.
fi_ar_comm
=
None
self
.
_init_communicators
()
...
...
@@ -161,6 +165,22 @@ class CommunicatorBenchmark:
)
self
.
symm_mem_comm_two_shot
=
None
try
:
self
.
fi_ar_comm
=
FlashInferAllReduce
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
if
not
self
.
fi_ar_comm
.
disabled
:
logger
.
info
(
"Rank %s: FlashInferAllReduce initialized"
,
self
.
rank
)
else
:
logger
.
info
(
"Rank %s: FlashInferAllReduce disabled"
,
self
.
rank
)
self
.
fi_ar_comm
=
None
except
Exception
as
e
:
logger
.
warning
(
"Rank %s: Failed to initialize FlashInferAllReduce: %s"
,
self
.
rank
,
e
)
self
.
fi_ar_comm
=
None
def
benchmark_allreduce
(
self
,
sequence_length
:
int
,
num_warmup
:
int
,
num_trials
:
int
)
->
dict
[
str
,
float
]:
...
...
@@ -180,7 +200,8 @@ class CommunicatorBenchmark:
lambda
t
,
c
=
comm
:
c
.
custom_all_reduce
(
t
),
lambda
t
,
c
=
comm
:
c
.
should_custom_ar
(
t
),
comm
.
capture
(),
"1stage"
,
# env variable value
{
"VLLM_CUSTOM_ALLREDUCE_ALGO"
:
"1stage"
},
None
,
# no destroy function
)
)
# CustomAllreduce two-shot
...
...
@@ -190,7 +211,8 @@ class CommunicatorBenchmark:
lambda
t
,
c
=
comm
:
c
.
custom_all_reduce
(
t
),
lambda
t
,
c
=
comm
:
c
.
should_custom_ar
(
t
),
comm
.
capture
(),
"2stage"
,
# env variable value
{
"VLLM_CUSTOM_ALLREDUCE_ALGO"
:
"2stage"
},
None
,
# no destroy function
)
)
...
...
@@ -202,7 +224,8 @@ class CommunicatorBenchmark:
lambda
t
,
c
=
comm
:
c
.
all_reduce
(
t
),
lambda
t
:
True
,
# Always available if initialized
nullcontext
(),
None
,
# no env variable needed
{},
# no env variable needed
None
,
# no destroy function
)
)
communicators
.
append
(
...
...
@@ -211,7 +234,8 @@ class CommunicatorBenchmark:
lambda
t
:
torch
.
ops
.
vllm
.
all_reduce_symmetric_with_copy
(
t
),
lambda
t
:
True
,
# Always available if initialized
nullcontext
(),
None
,
# no env variable needed
{},
# no env variable needed
None
,
# no destroy function
)
)
...
...
@@ -223,7 +247,8 @@ class CommunicatorBenchmark:
lambda
t
,
c
=
comm
:
c
.
all_reduce
(
t
),
lambda
t
,
c
=
comm
:
c
.
should_use_symm_mem
(
t
),
nullcontext
(),
None
,
# no env variable needed
{},
# no env variable needed
None
,
# no destroy function
)
)
...
...
@@ -235,29 +260,67 @@ class CommunicatorBenchmark:
lambda
t
,
c
=
comm
:
c
.
all_reduce
(
t
),
lambda
t
,
c
=
comm
:
c
.
should_use_symm_mem
(
t
),
nullcontext
(),
None
,
# no env variable needed
{},
# no env variable needed
None
,
# no destroy function needed
)
)
# Benchmark each communicator
for
name
,
allreduce_fn
,
should_use_fn
,
context
,
env_var
in
communicators
:
# Set environment variable if needed
if
env_var
is
not
None
:
os
.
environ
[
"VLLM_CUSTOM_ALLREDUCE_ALGO"
]
=
env_var
else
:
# Clear the environment variable to avoid interference
os
.
environ
.
pop
(
"VLLM_CUSTOM_ALLREDUCE_ALGO"
,
None
)
latency
=
self
.
benchmark_allreduce_single
(
sequence_length
,
allreduce_fn
,
should_use_fn
,
context
,
num_warmup
,
num_trials
,
if
self
.
fi_ar_comm
is
not
None
:
comm
=
self
.
fi_ar_comm
communicators
.
append
(
(
"flashinfer_trtllm"
,
lambda
t
,
c
=
comm
:
c
.
all_reduce
(
t
),
lambda
t
,
c
=
comm
:
c
.
should_use_fi_ar
(
t
),
nullcontext
(),
{
"VLLM_FLASHINFER_ALLREDUCE_BACKEND"
:
"trtllm"
},
lambda
c
=
comm
:
c
.
destroy
(),
)
)
if
latency
is
not
None
:
results
[
name
]
=
latency
communicators
.
append
(
(
"flashinfer_mnnvl"
,
lambda
t
,
c
=
comm
:
c
.
all_reduce
(
t
),
lambda
t
,
c
=
comm
:
c
.
should_use_fi_ar
(
t
),
nullcontext
(),
{
"VLLM_FLASHINFER_ALLREDUCE_BACKEND"
:
"mnnvl"
},
lambda
c
=
comm
:
c
.
destroy
(),
)
)
# Benchmark each communicator
for
(
name
,
allreduce_fn
,
should_use_fn
,
context
,
env_dict
,
destroy_fn
,
)
in
communicators
:
# Save original values and apply new environment variables
saved_env
=
{
key
:
os
.
environ
.
get
(
key
)
for
key
in
env_dict
}
for
key
,
value
in
env_dict
.
items
():
os
.
environ
[
key
]
=
value
try
:
latency
=
self
.
benchmark_allreduce_single
(
sequence_length
,
allreduce_fn
,
should_use_fn
,
context
,
num_warmup
,
num_trials
,
)
if
latency
is
not
None
:
results
[
name
]
=
latency
finally
:
if
destroy_fn
is
not
None
:
destroy_fn
()
# Restore environment variables to their original state
for
key
,
original_value
in
saved_env
.
items
():
if
original_value
is
None
:
os
.
environ
.
pop
(
key
,
None
)
else
:
os
.
environ
[
key
]
=
original_value
return
results
...
...
@@ -279,7 +342,7 @@ class CommunicatorBenchmark:
if
not
should_use_fn
(
tensor
):
return
None
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
stream
=
torch
.
cuda
.
Stream
()
with
torch
.
cuda
.
stream
(
stream
):
graph_input
=
tensor
.
clone
()
...
...
@@ -297,17 +360,17 @@ class CommunicatorBenchmark:
for
_
in
range
(
CUDA_GRAPH_CAPTURE_CYCLES
):
allreduce_fn
(
graph_input
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
for
_
in
range
(
num_warmup
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
start_time
=
time
.
perf_counter
()
for
_
in
range
(
num_trials
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
end_time
=
time
.
perf_counter
()
...
...
@@ -432,7 +495,7 @@ def main():
# Set device
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
accelerator
.
set_device
_index
(
device
)
# Get CPU process group
cpu_group
=
dist
.
new_group
(
backend
=
"gloo"
)
...
...
benchmarks/kernels/bench_fp8_gemm.py
→
benchmarks/kernels/bench
mark
_fp8_gemm.py
View file @
3fb4b5fa
File moved
benchmarks/kernels/benchmark_fused_collective.py
View file @
3fb4b5fa
...
...
@@ -5,8 +5,11 @@
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
1. FlashInfer's allreduce_fusion with trtllm backend
(fused allreduce + rmsnorm + optional FP8/FP4 quant)
2. FlashInfer's allreduce_fusion with mnnvl backend
(fused allreduce + rmsnorm only, no quantization support)
3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
...
...
@@ -24,7 +27,6 @@ import torch.distributed as dist # type: ignore
from
vllm.config.vllm
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.distributed
import
(
get_tp_group
,
tensor_model_parallel_all_reduce
,
)
from
vllm.distributed.parallel_state
import
(
...
...
@@ -49,14 +51,19 @@ SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
logger
=
init_logger
(
__name__
)
# Try to import FlashInfer
TorchDistBackend
=
None
try
:
import
flashinfer.comm
as
flashinfer_comm
# type: ignore
from
flashinfer.comm.mnnvl
import
(
# type: ignore
TorchDistBackend
,
)
if
not
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
):
if
not
(
hasattr
(
flashinfer_comm
,
"allreduce_fusion"
)
and
hasattr
(
flashinfer_comm
,
"create_allreduce_fusion_workspace"
)
):
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
logger
.
warning
(
"FlashInfer comm module found but missing allreduce_fusion API"
)
except
ImportError
:
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer not found, only benchmarking standard operations"
)
...
...
@@ -74,57 +81,70 @@ _FI_MAX_SIZES = {
8
:
64
*
MiB
,
# 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR
=
None
# Global workspace tensors for FlashInfer (keyed by backend name)
_FI_WORKSPACES
:
dict
=
{}
# Backends to benchmark
FLASHINFER_BACKENDS
=
[
"trtllm"
,
"mnnvl"
]
def
setup_flashinfer_workspace
(
backend
:
str
,
world_size
:
int
,
rank
:
int
,
hidden_dim
:
int
,
max_token_num
:
int
,
use_fp32_lamport
:
bool
=
Fals
e
,
dtype
:
torch
.
dtyp
e
,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global
_
FI_WORKSPACE
_TENSOR
global
FI_WORKSPACE
S
if
flashinfer_comm
is
None
:
return
None
,
None
return
None
if
world_size
not
in
_FI_MAX_SIZES
:
logger
.
warning
(
"FlashInfer not supported for world size %s"
,
world_size
)
return
None
,
None
return
None
try
:
# Create IPC workspace
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
tp_rank
=
rank
,
tp_size
=
world_size
,
max_token_num
=
max_token_num
,
hidden_dim
=
hidden_dim
,
group
=
get_tp_group
().
device_group
,
use_fp32_lamport
=
use_fp32_lamport
,
)
kwargs
=
{}
if
TorchDistBackend
is
not
None
:
kwargs
[
"comm_backend"
]
=
TorchDistBackend
(
group
=
dist
.
group
.
WORLD
)
workspace
=
flashinfer_comm
.
create_allreduce_fusion_workspace
(
backend
=
backend
,
world_size
=
world_size
,
rank
=
rank
,
max_token_num
=
max_token_num
,
hidden_dim
=
hidden_dim
,
dtype
=
dtype
,
**
kwargs
,
)
_FI_WORKSPACE
_TENSOR
=
workspace
_tensor
return
ipc_handles
,
workspace
_tensor
_FI_WORKSPACE
S
[
backend
]
=
workspace
return
workspace
except
Exception
as
e
:
logger
.
error
(
"Failed to setup FlashInfer workspace: %s"
,
e
)
return
None
,
None
logger
.
error
(
"Failed to setup FlashInfer workspace (backend=%s): %s"
,
backend
,
e
)
return
None
def
cleanup_flashinfer_workspace
(
ipc_handle
s
):
"""Cleanup FlashInfer workspace."""
if
flashinfer_comm
is
None
or
ipc_handles
is
None
:
def
cleanup_flashinfer_workspaces
(
):
"""Cleanup
all
FlashInfer workspace
s
."""
if
flashinfer_comm
is
None
:
return
try
:
group
=
get_tp_group
().
device_group
flashinfer_comm
.
trtllm_destroy_ipc_workspace_for_all_reduce
(
ipc_handles
,
group
)
except
Exception
as
e
:
logger
.
error
(
"Failed to cleanup FlashInfer workspace: %s"
,
e
)
for
backend
,
workspace
in
_FI_WORKSPACES
.
items
():
try
:
workspace
.
destroy
()
except
Exception
as
e
:
logger
.
error
(
"Failed to cleanup FlashInfer workspace (backend=%s): %s"
,
backend
,
e
,
)
_FI_WORKSPACES
.
clear
()
class
FlashInferFusedAllReduceParams
:
...
...
@@ -132,25 +152,15 @@ class FlashInferFusedAllReduceParams:
def
__init__
(
self
,
rank
:
int
,
world_size
:
int
,
use_fp32_lamport
:
bool
=
False
,
max_token_num
:
int
=
1024
,
):
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
use_fp32_lamport
=
use_fp32_lamport
self
.
trigger_completion_at_end
=
True
self
.
launch_with_pdl
=
True
self
.
fp32_acc
=
True
self
.
max_token_num
=
max_token_num
def
get_
trtllm
_fused_allreduce_kwargs
(
self
):
def
get_
flashinfer
_fused_allreduce_kwargs
(
self
):
return
{
"world_rank"
:
self
.
rank
,
"world_size"
:
self
.
world_size
,
"launch_with_pdl"
:
self
.
launch_with_pdl
,
"trigger_completion_at_end"
:
self
.
trigger_completion_at_end
,
"fp32_acc"
:
self
.
fp32_acc
,
}
...
...
@@ -161,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm(
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
allreduce_params
:
"FlashInferFusedAllReduceParams"
,
workspace
:
object
,
use_oneshot
:
bool
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
if
flashinfer_comm
is
None
or
workspace
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
...
...
@@ -174,24 +185,25 @@ def flashinfer_fused_allreduce_rmsnorm(
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
layout_code
=
None
if
workspace
.
backend
==
"trtllm"
:
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
flashinfer_comm
.
allreduce_fusion
(
input
=
input_tensor
,
workspace
=
workspace
,
pattern
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNorm
,
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNorm
,
allreduce_out
=
None
,
quant_out
=
None
,
scale_out
=
None
,
layout_code
=
f
la
shinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
layout_code
=
la
yout_code
,
scale_factor
=
None
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_
trtllm
_fused_allreduce_kwargs
(),
**
allreduce_params
.
get_
flashinfer
_fused_allreduce_kwargs
(),
)
...
...
@@ -202,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
rms_eps
:
float
,
scale_factor
:
torch
.
Tensor
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
workspace
:
object
,
use_oneshot
:
bool
=
True
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
quant_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization.
Note: Only supported by the trtllm backend.
"""
if
flashinfer_comm
is
None
or
workspace
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
...
...
@@ -216,24 +232,21 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
flashinfer_comm
.
allreduce_fusion
(
input
=
input_tensor
,
workspace
=
workspace
,
pattern
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP8Quant
,
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP8Quant
,
allreduce_out
=
None
,
quant_out
=
quant_out
,
scale_out
=
None
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
scale_factor
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_
trtllm
_fused_allreduce_kwargs
(),
**
allreduce_params
.
get_
flashinfer
_fused_allreduce_kwargs
(),
)
...
...
@@ -244,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
rms_eps
:
float
,
input_global_scale
:
torch
.
Tensor
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
workspace
:
object
,
quant_out
:
torch
.
Tensor
,
use_oneshot
:
bool
,
output_scale
:
torch
.
Tensor
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization.
Note: Only supported by the trtllm backend.
"""
if
flashinfer_comm
is
None
or
workspace
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
...
...
@@ -259,24 +276,21 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
flashinfer_comm
.
allreduce_fusion
(
input
=
input_tensor
,
workspace
=
workspace
,
pattern
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP4Quant
,
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP4Quant
,
allreduce_out
=
None
,
quant_out
=
quant_out
,
scale_out
=
output_scale
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
input_global_scale
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_
trtllm
_fused_allreduce_kwargs
(),
**
allreduce_params
.
get_
flashinfer
_fused_allreduce_kwargs
(),
)
...
...
@@ -371,32 +385,32 @@ def benchmark_operation(
# Warmup before graph capture
for
_
in
range
(
warmup
):
operation_func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
# Create CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
num_op_per_cudagraph
=
10
# Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
device
=
torch
.
device
(
f
"cuda:
{
torch
.
cuda
.
current_device
()
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
torch
.
accelerator
.
current_device
_index
()
}
"
)
with
graph_capture
(
device
=
device
),
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
num_op_per_cudagraph
):
operation_func
(
*
args
,
**
kwargs
)
# Graph warmup
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
for
_
in
range
(
warmup
):
graph
.
replay
()
# Benchmark with CUDA graph
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
start_time
=
time
.
perf_counter
()
for
_
in
range
(
trials
//
num_op_per_cudagraph
):
# operation_func(*args, **kwargs)
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
end_time
=
time
.
perf_counter
()
avg_time_ms
=
((
end_time
-
start_time
)
/
trials
)
*
1000
...
...
@@ -409,13 +423,16 @@ def run_benchmarks(
dtype
:
torch
.
dtype
,
use_residual
:
bool
,
allreduce_params
:
FlashInferFusedAllReduceParams
|
None
,
workspaces
:
dict
,
quant_modes
:
set
[
str
],
no_oneshot
:
bool
,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
allreduce_params: Shared parameters for FlashInfer fused allreduce.
workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace.
quant_modes: Set of quantization modes: "none", "fp8", "fp4".
"""
(
input_tensor
,
...
...
@@ -431,18 +448,18 @@ def run_benchmarks(
rms_eps
=
1e-6
results
=
{}
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
use_oneshot_options
=
[
False
]
if
no_oneshot
else
[
True
,
False
]
# Create RMSNorm and QuantFP8 layers once for native benchmarks
if
"none"
in
quant_modes
:
# Standard AllReduce + RMSNorm
# Re-create VllmFusedAllreduce per config so CustomOp binds the
# correct forward method (native vs custom kernel).
for
custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
custom_op
]))
):
try
:
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
suffix
=
(
"_custom_rms_norm"
if
"+"
in
custom_op
else
"_native_rms_norm"
)
...
...
@@ -461,6 +478,7 @@ def run_benchmarks(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
]))
):
try
:
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
standard_allreduce_rmsnorm_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm
,
fullgraph
=
True
,
...
...
@@ -476,10 +494,11 @@ def run_benchmarks(
logger
.
error
(
"Standard AllReduce+RMSNorm Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm
Oneshot/Twoshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
# FlashInfer Fused AllReduce + RMSNorm
(all backends)
for
backend
,
workspace
in
workspaces
.
items
()
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
key
=
f
"flashinfer_
{
backend
}
_fused_allreduce_rmsnorm
{
suffix
}
"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm
,
...
...
@@ -489,14 +508,17 @@ def run_benchmarks(
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
allreduce_params
=
allreduce_params
,
workspace
=
workspace
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm
{
suffix
}
"
]
=
time_ms
results
[
key
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm failed: %s"
,
e
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm
{
suffix
}
"
]
=
float
(
"inf"
logger
.
error
(
"FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s"
,
backend
,
e
,
)
results
[
key
]
=
float
(
"inf"
)
if
"fp8"
in
quant_modes
:
# Standard AllReduce + RMSNorm + FP8 Quant
...
...
@@ -505,7 +527,7 @@ def run_benchmarks(
"_custom_rms_norm"
if
"+"
in
rms_norm_custom_op
else
"_native_rms_norm"
)
for
quant_fp8_custom_op
in
[
"-quant_fp8"
,
"+quant_fp8"
]:
suffix
+
=
(
op_suffix
=
suffix
+
(
"_custom_quant_fp8"
if
"+"
in
quant_fp8_custom_op
else
"_native_quant_fp8"
...
...
@@ -518,16 +540,17 @@ def run_benchmarks(
)
):
try
:
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp8_quant
,
input_tensor
,
residual
=
residual
,
scale_factor
=
scale_fp8
,
)
results
[
f
"standard_allreduce
{
suffix
}
"
]
=
time_ms
results
[
f
"standard_allreduce
{
op_
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP8 failed: %s"
,
e
)
results
[
f
"standard_allreduce
{
suffix
}
"
]
=
float
(
"inf"
)
results
[
f
"standard_allreduce
{
op_
suffix
}
"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
with
set_current_vllm_config
(
...
...
@@ -538,6 +561,7 @@ def run_benchmarks(
)
):
try
:
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
standard_allreduce_rmsnorm_fp8_quant_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp8_quant
,
fullgraph
=
True
,
...
...
@@ -560,10 +584,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only)
if
"trtllm"
in
workspaces
:
trtllm_ws
=
workspaces
[
"trtllm"
]
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
key
=
f
"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp8_quant
,
...
...
@@ -575,19 +601,16 @@ def run_benchmarks(
scale_factor
=
scale_fp8
,
quant_out
=
quant_out_fp8
,
allreduce_params
=
allreduce_params
,
workspace
=
trtllm_ws
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
]
=
(
time_ms
)
results
[
key
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP8
Oneshot
failed: %s"
,
"FlashInfer
(trtllm)
Fused AllReduce+RMSNorm+FP8 failed: %s"
,
e
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
]
=
(
float
(
"inf"
)
)
results
[
key
]
=
float
(
"inf"
)
if
"fp4"
in
quant_modes
and
current_platform
.
has_device_capability
(
100
):
# Standard AllReduce + RMSNorm + FP4 Quant
...
...
@@ -603,6 +626,7 @@ def run_benchmarks(
)
):
try
:
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp4_quant
,
input_tensor
,
...
...
@@ -621,6 +645,7 @@ def run_benchmarks(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
]))
):
try
:
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
standard_allreduce_rmsnorm_fp4_quant_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp4_quant
,
fullgraph
=
True
,
...
...
@@ -645,10 +670,12 @@ def run_benchmarks(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only)
if
"trtllm"
in
workspaces
:
trtllm_ws
=
workspaces
[
"trtllm"
]
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
key
=
f
"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp4_quant
,
...
...
@@ -659,49 +686,18 @@ def run_benchmarks(
rms_eps
=
rms_eps
,
input_global_scale
=
scale_fp4
,
allreduce_params
=
allreduce_params
,
workspace
=
trtllm_ws
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
]
=
(
time_ms
)
results
[
key
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP4
Oneshot
failed: %s"
,
"FlashInfer
(trtllm)
Fused AllReduce+RMSNorm+FP4 failed: %s"
,
e
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
]
=
(
float
(
"inf"
)
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
input_global_scale
=
scale_fp4
,
allreduce_params
=
allreduce_params
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
use_oneshot
=
False
,
)
results
[
"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s"
,
e
,
)
results
[
"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"
]
=
float
(
"inf"
)
results
[
key
]
=
float
(
"inf"
)
return
results
...
...
@@ -988,7 +984,7 @@ def main():
world_size
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
accelerator
.
set_device
_index
(
device
)
torch
.
set_default_device
(
device
)
init_distributed_environment
()
...
...
@@ -1039,24 +1035,33 @@ def main():
configs
=
list
(
itertools
.
product
(
args
.
num_tokens
,
dtypes
,
residual_options
))
# Setup FlashInfer workspace if available
ipc_handles
=
None
# Setup FlashInfer workspaces for all backends
allreduce_params
=
None
if
flashinfer_comm
is
not
None
:
# Use the largest hidden dimension for workspace setup
max_element_size
=
max
(
torch
.
finfo
(
dt
).
bits
//
8
for
dt
in
dtypes
)
workspace_dtype
=
(
torch
.
float32
if
max_element_size
==
4
else
(
torch
.
bfloat16
if
torch
.
bfloat16
in
dtypes
else
torch
.
float16
)
)
max_num_token
=
_FI_MAX_SIZES
.
get
(
world_size
)
//
(
args
.
hidden_dim
*
world
_size
*
2
args
.
hidden_dim
*
max_element
_size
)
ipc_handles
,
workspace_tensor
=
setup_flashinfer_workspace
(
world_size
,
rank
,
args
.
hidden_dim
,
max_num_token
)
for
backend
in
FLASHINFER_BACKENDS
:
setup_flashinfer_workspace
(
backend
=
backend
,
world_size
=
world_size
,
rank
=
rank
,
hidden_dim
=
args
.
hidden_dim
,
max_token_num
=
max_num_token
,
dtype
=
workspace_dtype
,
)
if
workspace_tensor
is
not
None
:
if
_FI_WORKSPACES
:
allreduce_params
=
FlashInferFusedAllReduceParams
(
rank
=
rank
,
world_size
=
world_size
,
max_token_num
=
max_num_token
,
)
...
...
@@ -1081,6 +1086,7 @@ def main():
dtype
,
use_residual
,
allreduce_params
,
workspaces
=
_FI_WORKSPACES
,
quant_modes
=
quant_modes
,
no_oneshot
=
args
.
no_oneshot
,
)
...
...
@@ -1119,11 +1125,13 @@ def main():
finally
:
# Cleanup
if
ipc_handles
is
not
None
:
cleanup_flashinfer_workspace
(
ipc_handles
)
cleanup_flashinfer_workspaces
()
dist
.
barrier
()
if
__name__
==
"__main__"
:
main
()
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
with
set_current_vllm_config
(
VllmConfig
()):
main
()
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
3fb4b5fa
...
...
@@ -9,15 +9,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from
tests.kernels.moe.utils
import
make_dummy_moe_config
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.all2all_utils
import
(
maybe_make_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.config
import
fp8_w8a8_moe_quant_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
CutlassExpertsFp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_topk
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.v1.worker.workspace
import
init_workspace_manager
...
...
@@ -50,7 +50,7 @@ def bench_run(
per_out_ch
:
bool
,
mkn
:
tuple
[
int
,
int
,
int
],
):
init_workspace_manager
(
torch
.
cuda
.
current_device
())
init_workspace_manager
(
torch
.
accelerator
.
current_device
_index
())
label
=
"Quant Matmul"
sub_label
=
(
...
...
@@ -131,16 +131,22 @@ def bench_run(
w2_scale
=
w2_scale
,
per_act_token_quant
=
per_act_token
,
)
moe_config
=
make_dummy_moe_config
(
num_experts
=
w2
.
shape
[
0
],
hidden_dim
=
w2
.
shape
[
1
],
intermediate_size_per_partition
=
w2
.
shape
[
2
],
in_dtype
=
a
.
dtype
,
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
fn
=
mk
.
FusedMoEKernel
(
maybe_make_prepare_finalize
(
moe
=
moe_config
,
quant_config
=
quant_config
,
allow_new_interface
=
True
,
use_monolithic
=
False
,
),
CutlassExpertsFp8
(
moe_config
=
make_dummy_moe_config
(
num_experts
=
w2
.
shape
[
0
],
hidden_dim
=
w2
.
shape
[
1
],
intermediate_size_per_partition
=
w2
.
shape
[
2
],
in_dtype
=
a
.
dtype
,
),
moe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
...
...
@@ -163,16 +169,22 @@ def bench_run(
w2_scale
=
w2_scale
,
per_act_token_quant
=
per_act_token
,
)
moe_config
=
make_dummy_moe_config
(
num_experts
=
w2
.
shape
[
0
],
hidden_dim
=
w2
.
shape
[
1
],
intermediate_size_per_partition
=
w2
.
shape
[
2
],
in_dtype
=
a
.
dtype
,
)
fn
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
fn
=
mk
.
FusedMoEKernel
(
maybe_make_prepare_finalize
(
moe
=
moe_config
,
quant_config
=
quant_config
,
allow_new_interface
=
True
,
use_monolithic
=
False
,
),
CutlassExpertsFp8
(
moe_config
=
make_dummy_moe_config
(
num_experts
=
w2
.
shape
[
0
],
hidden_dim
=
w2
.
shape
[
1
],
intermediate_size_per_partition
=
w2
.
shape
[
2
],
in_dtype
=
a
.
dtype
,
),
moe_config
=
moe_config
,
quant_config
=
quant_config
,
),
)
...
...
@@ -212,7 +224,7 @@ def bench_run(
def
replay_graph
(
graph
,
num_repeats
):
for
_
in
range
(
num_repeats
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
cutlass_stream
=
torch
.
cuda
.
Stream
()
cutlass_graph
=
torch
.
cuda
.
CUDAGraph
()
...
...
@@ -227,7 +239,7 @@ def bench_run(
topk_weights
,
topk_ids
,
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
triton_stream
=
torch
.
cuda
.
Stream
()
triton_graph
=
torch
.
cuda
.
CUDAGraph
()
...
...
@@ -242,7 +254,7 @@ def bench_run(
w2_scale
,
a_scale
,
)
torch
.
cuda
.
synchronize
()
torch
.
accelerator
.
synchronize
()
min_run_time
=
5
num_warmup
=
5
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment