Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
006693ed
Commit
006693ed
authored
Dec 01, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.11.2' into v0.11.2-ori
parents
4b51e6f1
275de341
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1616 additions
and
95 deletions
+1616
-95
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+4
-5
benchmarks/benchmark_prioritization.py
benchmarks/benchmark_prioritization.py
+2
-3
benchmarks/benchmark_serving_structured_output.py
benchmarks/benchmark_serving_structured_output.py
+10
-17
benchmarks/benchmark_utils.py
benchmarks/benchmark_utils.py
+8
-8
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+2
-3
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+9
-9
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
+2
-6
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
+4
-12
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
+4
-5
benchmarks/kernels/bench_block_fp8_gemm.py
benchmarks/kernels/bench_block_fp8_gemm.py
+29
-14
benchmarks/kernels/bench_mxfp4_qutlass.py
benchmarks/kernels/bench_mxfp4_qutlass.py
+191
-0
benchmarks/kernels/bench_nvfp4_qutlass.py
benchmarks/kernels/bench_nvfp4_qutlass.py
+207
-0
benchmarks/kernels/bench_per_token_quant_fp8.py
benchmarks/kernels/bench_per_token_quant_fp8.py
+3
-2
benchmarks/kernels/benchmark_activation.py
benchmarks/kernels/benchmark_activation.py
+2
-1
benchmarks/kernels/benchmark_bitblas.py
benchmarks/kernels/benchmark_bitblas.py
+1
-1
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
+1
-1
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+1
-1
benchmarks/kernels/benchmark_device_communicators.py
benchmarks/kernels/benchmark_device_communicators.py
+4
-4
benchmarks/kernels/benchmark_fused_collective.py
benchmarks/kernels/benchmark_fused_collective.py
+1129
-0
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+3
-3
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
benchmarks/benchmark_prefix_caching.py
View file @
006693ed
...
@@ -32,13 +32,12 @@ import dataclasses
...
@@ -32,13 +32,12 @@ import dataclasses
import
json
import
json
import
random
import
random
import
time
import
time
from
typing
import
Optional
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
try
:
try
:
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
@@ -70,7 +69,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
...
@@ -70,7 +69,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
# Remove the special tokens.
# Remove the special tokens.
return
random
.
choices
(
return
random
.
choices
(
[
v
for
k
,
v
in
vocab
.
item
s
()
if
k
not
in
all_special_ids
],
[
v
for
v
in
vocab
.
value
s
()
if
v
not
in
all_special_ids
],
k
=
length
,
k
=
length
,
)
)
...
@@ -80,7 +79,7 @@ def sample_requests_from_dataset(
...
@@ -80,7 +79,7 @@ def sample_requests_from_dataset(
num_requests
:
int
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
input_length_range
:
tuple
[
int
,
int
],
input_length_range
:
tuple
[
int
,
int
],
fixed_output_len
:
Optional
[
int
]
,
fixed_output_len
:
int
|
None
,
)
->
list
[
Request
]:
)
->
list
[
Request
]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
raise
ValueError
(
"output_len too small"
)
...
@@ -128,7 +127,7 @@ def sample_requests_from_random(
...
@@ -128,7 +127,7 @@ def sample_requests_from_random(
num_requests
:
int
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
input_length_range
:
tuple
[
int
,
int
],
input_length_range
:
tuple
[
int
,
int
],
fixed_output_len
:
Optional
[
int
]
,
fixed_output_len
:
int
|
None
,
prefix_len
:
int
,
prefix_len
:
int
,
)
->
list
[
Request
]:
)
->
list
[
Request
]:
requests
=
[]
requests
=
[]
...
...
benchmarks/benchmark_prioritization.py
View file @
006693ed
...
@@ -7,12 +7,11 @@ import dataclasses
...
@@ -7,12 +7,11 @@ import dataclasses
import
json
import
json
import
random
import
random
import
time
import
time
from
typing
import
Optional
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# Select a equi-probable random priority
# Select a equi-probable random priority
...
@@ -24,7 +23,7 @@ def sample_requests(
...
@@ -24,7 +23,7 @@ def sample_requests(
dataset_path
:
str
,
dataset_path
:
str
,
num_requests
:
int
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
fixed_output_len
:
Optional
[
int
]
,
fixed_output_len
:
int
|
None
,
)
->
list
[
tuple
[
str
,
int
,
int
,
int
]]:
)
->
list
[
tuple
[
str
,
int
,
int
,
int
]]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
raise
ValueError
(
"output_len too small"
)
...
...
benchmarks/benchmark_serving_structured_output.py
View file @
006693ed
...
@@ -31,20 +31,19 @@ import time
...
@@ -31,20 +31,19 @@ import time
import
uuid
import
uuid
import
warnings
import
warnings
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
from
contextlib
import
nullcontext
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
from
tqdm.asyncio
import
tqdm
from
transformers
import
PreTrainedTokenizerBase
from
backend_request_func
import
(
from
backend_request_func
import
(
ASYNC_REQUEST_FUNCS
,
ASYNC_REQUEST_FUNCS
,
RequestFuncInput
,
RequestFuncInput
,
RequestFuncOutput
,
RequestFuncOutput
,
)
)
from
tqdm.asyncio
import
tqdm
from
transformers
import
PreTrainedTokenizerBase
try
:
try
:
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
@@ -52,7 +51,7 @@ except ImportError:
...
@@ -52,7 +51,7 @@ except ImportError:
from
backend_request_func
import
get_tokenizer
from
backend_request_func
import
get_tokenizer
try
:
try
:
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
except
ImportError
:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
...
@@ -317,7 +316,7 @@ def calculate_metrics(
...
@@ -317,7 +316,7 @@ def calculate_metrics(
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
selected_percentile_metrics
:
list
[
str
],
selected_percentile_metrics
:
list
[
str
],
selected_percentiles
:
list
[
float
],
selected_percentiles
:
list
[
float
],
goodput_config_dict
:
Optional
[
dict
[
str
,
float
]
]
=
None
,
goodput_config_dict
:
dict
[
str
,
float
]
|
None
=
None
,
)
->
tuple
[
BenchmarkMetrics
,
list
[
int
]]:
)
->
tuple
[
BenchmarkMetrics
,
list
[
int
]]:
actual_output_lens
:
list
[
int
]
=
[]
actual_output_lens
:
list
[
int
]
=
[]
total_input
=
0
total_input
=
0
...
@@ -437,9 +436,9 @@ async def benchmark(
...
@@ -437,9 +436,9 @@ async def benchmark(
selected_percentile_metrics
:
list
[
str
],
selected_percentile_metrics
:
list
[
str
],
selected_percentiles
:
list
[
str
],
selected_percentiles
:
list
[
str
],
ignore_eos
:
bool
,
ignore_eos
:
bool
,
max_concurrency
:
Optional
[
int
]
,
max_concurrency
:
int
|
None
,
structured_output_ratio
:
float
,
structured_output_ratio
:
float
,
goodput_config_dict
:
Optional
[
dict
[
str
,
float
]
]
=
None
,
goodput_config_dict
:
dict
[
str
,
float
]
|
None
=
None
,
):
):
if
backend
in
ASYNC_REQUEST_FUNCS
:
if
backend
in
ASYNC_REQUEST_FUNCS
:
request_func
=
ASYNC_REQUEST_FUNCS
[
backend
]
request_func
=
ASYNC_REQUEST_FUNCS
[
backend
]
...
@@ -503,15 +502,9 @@ async def benchmark(
...
@@ -503,15 +502,9 @@ async def benchmark(
pbar
=
None
if
disable_tqdm
else
tqdm
(
total
=
len
(
input_requests
))
pbar
=
None
if
disable_tqdm
else
tqdm
(
total
=
len
(
input_requests
))
# This can be used once the minimum Python version is 3.10 or higher,
semaphore
=
asyncio
.
Semaphore
(
max_concurrency
)
if
max_concurrency
else
nullcontext
()
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore
=
asyncio
.
Semaphore
(
max_concurrency
)
if
max_concurrency
else
None
async
def
limited_request_func
(
request_func_input
,
pbar
):
async
def
limited_request_func
(
request_func_input
,
pbar
):
if
semaphore
is
None
:
return
await
request_func
(
request_func_input
=
request_func_input
,
pbar
=
pbar
)
async
with
semaphore
:
async
with
semaphore
:
return
await
request_func
(
request_func_input
=
request_func_input
,
pbar
=
pbar
)
return
await
request_func
(
request_func_input
=
request_func_input
,
pbar
=
pbar
)
...
@@ -910,13 +903,13 @@ def create_argument_parser():
...
@@ -910,13 +903,13 @@ def create_argument_parser():
parser
.
add_argument
(
parser
.
add_argument
(
"--tokenizer"
,
"--tokenizer"
,
type
=
str
,
type
=
str
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--tokenizer-mode"
,
"--tokenizer-mode"
,
type
=
str
,
type
=
str
,
default
=
"auto"
,
default
=
"auto"
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num-prompts"
,
"--num-prompts"
,
...
...
benchmarks/benchmark_utils.py
View file @
006693ed
...
@@ -6,7 +6,7 @@ import math
...
@@ -6,7 +6,7 @@ import math
import
os
import
os
import
time
import
time
from
types
import
TracebackType
from
types
import
TracebackType
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
def
convert_to_pytorch_benchmark_format
(
def
convert_to_pytorch_benchmark_format
(
...
@@ -92,7 +92,7 @@ class TimeCollector:
...
@@ -92,7 +92,7 @@ class TimeCollector:
def
__init__
(
self
,
scale
:
int
)
->
None
:
def
__init__
(
self
,
scale
:
int
)
->
None
:
self
.
cnt
:
int
=
0
self
.
cnt
:
int
=
0
self
.
_sum
:
int
=
0
self
.
_sum
:
int
=
0
self
.
_max
:
Optional
[
int
]
=
None
self
.
_max
:
int
|
None
=
None
self
.
scale
=
scale
self
.
scale
=
scale
self
.
start_time
:
int
=
time
.
monotonic_ns
()
self
.
start_time
:
int
=
time
.
monotonic_ns
()
...
@@ -104,13 +104,13 @@ class TimeCollector:
...
@@ -104,13 +104,13 @@ class TimeCollector:
else
:
else
:
self
.
_max
=
max
(
self
.
_max
,
v
)
self
.
_max
=
max
(
self
.
_max
,
v
)
def
avg
(
self
)
->
Union
[
float
,
str
]
:
def
avg
(
self
)
->
float
|
str
:
return
self
.
_sum
*
1.0
/
self
.
cnt
/
self
.
scale
if
self
.
cnt
>
0
else
"N/A"
return
self
.
_sum
*
1.0
/
self
.
cnt
/
self
.
scale
if
self
.
cnt
>
0
else
"N/A"
def
max
(
self
)
->
Union
[
float
,
str
]
:
def
max
(
self
)
->
float
|
str
:
return
self
.
_max
/
self
.
scale
if
self
.
_max
else
"N/A"
return
self
.
_max
/
self
.
scale
if
self
.
_max
else
"N/A"
def
dump_avg_max
(
self
)
->
list
[
Union
[
float
,
str
]
]
:
def
dump_avg_max
(
self
)
->
list
[
float
|
str
]:
return
[
self
.
avg
(),
self
.
max
()]
return
[
self
.
avg
(),
self
.
max
()]
def
__enter__
(
self
)
->
None
:
def
__enter__
(
self
)
->
None
:
...
@@ -118,8 +118,8 @@ class TimeCollector:
...
@@ -118,8 +118,8 @@ class TimeCollector:
def
__exit__
(
def
__exit__
(
self
,
self
,
exc_type
:
Optional
[
type
[
BaseException
]
]
,
exc_type
:
type
[
BaseException
]
|
None
,
exc_value
:
Optional
[
BaseException
]
,
exc_value
:
BaseException
|
None
,
exc_traceback
:
Optional
[
TracebackType
]
,
exc_traceback
:
TracebackType
|
None
,
)
->
None
:
)
->
None
:
self
.
collect
(
time
.
monotonic_ns
()
-
self
.
start_time
)
self
.
collect
(
time
.
monotonic_ns
()
-
self
.
start_time
)
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
View file @
006693ed
...
@@ -6,8 +6,7 @@ import copy
...
@@ -6,8 +6,7 @@ import copy
import
itertools
import
itertools
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Callable
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -16,7 +15,7 @@ from utils import make_rand_sparse_tensors
...
@@ -16,7 +15,7 @@ from utils import make_rand_sparse_tensors
from
weight_shapes
import
WEIGHT_SHAPES
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
...
...
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
006693ed
...
@@ -6,8 +6,7 @@ import copy
...
@@ -6,8 +6,7 @@ import copy
import
itertools
import
itertools
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Callable
,
Optional
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -17,9 +16,10 @@ from weight_shapes import WEIGHT_SHAPES
...
@@ -17,9 +16,10 @@ from weight_shapes import WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
w8a8_
block_fp8_matmul
,
w8a8_
triton_block_scaled_mm
,
)
)
from
vllm.utils
import
FlexibleArgumentParser
,
cdiv
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.math_utils
import
cdiv
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
...
@@ -53,7 +53,7 @@ def bench_int8(
...
@@ -53,7 +53,7 @@ def bench_int8(
n
:
int
,
n
:
int
,
label
:
str
,
label
:
str
,
sub_label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
"""Benchmark INT8-based kernels."""
"""Benchmark INT8-based kernels."""
assert
dtype
==
torch
.
int8
assert
dtype
==
torch
.
int8
...
@@ -108,7 +108,7 @@ def bench_fp8(
...
@@ -108,7 +108,7 @@ def bench_fp8(
n
:
int
,
n
:
int
,
label
:
str
,
label
:
str
,
sub_label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
"""Benchmark FP8-based kernels."""
"""Benchmark FP8-based kernels."""
assert
dtype
==
torch
.
float8_e4m3fn
assert
dtype
==
torch
.
float8_e4m3fn
...
@@ -158,7 +158,7 @@ def bench_fp8(
...
@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)
),
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
w8a8_
block_fp8_matmul
(
"triton_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
w8a8_
triton_block_scaled_mm
(
a_cont
,
b
.
t
(),
block_scale_a
,
block_scale_b
.
t
(),
(
128
,
128
)
a_cont
,
b
.
t
(),
block_scale_a
,
block_scale_b
.
t
(),
(
128
,
128
)
),
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
ops
.
cutlass_scaled_mm
(
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
ops
.
cutlass_scaled_mm
(
...
@@ -183,7 +183,7 @@ def bench(
...
@@ -183,7 +183,7 @@ def bench(
n
:
int
,
n
:
int
,
label
:
str
,
label
:
str
,
sub_label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
if
dtype
==
torch
.
int8
:
if
dtype
==
torch
.
int8
:
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
...
@@ -201,7 +201,7 @@ def print_timers(timers: Iterable[TMeasurement]):
...
@@ -201,7 +201,7 @@ def print_timers(timers: Iterable[TMeasurement]):
def
run
(
def
run
(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]],
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]],
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
results
=
[]
results
=
[]
for
m
,
k
,
n
in
MKNs
:
for
m
,
k
,
n
in
MKNs
:
...
...
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
View file @
006693ed
...
@@ -55,9 +55,7 @@ benchmark() {
...
@@ -55,9 +55,7 @@ benchmark() {
output_len
=
$2
output_len
=
$2
CUDA_VISIBLE_DEVICES
=
0 python3
\
CUDA_VISIBLE_DEVICES
=
0 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8100
\
--port
8100
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
...
@@ -65,9 +63,7 @@ benchmark() {
...
@@ -65,9 +63,7 @@ benchmark() {
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
CUDA_VISIBLE_DEVICES
=
1 python3
\
CUDA_VISIBLE_DEVICES
=
1 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8200
\
--port
8200
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
...
...
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
View file @
006693ed
...
@@ -38,16 +38,12 @@ wait_for_server() {
...
@@ -38,16 +38,12 @@ wait_for_server() {
launch_chunked_prefill
()
{
launch_chunked_prefill
()
{
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
# disagg prefill
CUDA_VISIBLE_DEVICES
=
0 python3
\
CUDA_VISIBLE_DEVICES
=
0 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8100
\
--port
8100
\
--max-model-len
10000
\
--max-model-len
10000
\
--enable-chunked-prefill
\
--enable-chunked-prefill
\
--gpu-memory-utilization
0.6 &
--gpu-memory-utilization
0.6 &
CUDA_VISIBLE_DEVICES
=
1 python3
\
CUDA_VISIBLE_DEVICES
=
1 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8200
\
--port
8200
\
--max-model-len
10000
\
--max-model-len
10000
\
--enable-chunked-prefill
\
--enable-chunked-prefill
\
...
@@ -62,18 +58,14 @@ launch_chunked_prefill() {
...
@@ -62,18 +58,14 @@ launch_chunked_prefill() {
launch_disagg_prefill
()
{
launch_disagg_prefill
()
{
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
# disagg prefill
CUDA_VISIBLE_DEVICES
=
0 python3
\
CUDA_VISIBLE_DEVICES
=
0 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8100
\
--port
8100
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
--kv-transfer-config
\
--kv-transfer-config
\
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
CUDA_VISIBLE_DEVICES
=
1 python3
\
CUDA_VISIBLE_DEVICES
=
1 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8200
\
--port
8200
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
...
...
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
View file @
006693ed
...
@@ -3,10 +3,9 @@
...
@@ -3,10 +3,9 @@
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
product
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]:
...
@@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]:
def
unfused_int8_impl
(
def
unfused_int8_impl
(
rms_norm_layer
:
RMSNorm
,
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
):
):
# Norm
# Norm
...
@@ -68,7 +67,7 @@ def unfused_int8_impl(
...
@@ -68,7 +67,7 @@ def unfused_int8_impl(
def
unfused_fp8_impl
(
def
unfused_fp8_impl
(
rms_norm_layer
:
RMSNorm
,
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
):
):
# Norm
# Norm
...
@@ -85,7 +84,7 @@ def unfused_fp8_impl(
...
@@ -85,7 +84,7 @@ def unfused_fp8_impl(
def
fused_impl
(
def
fused_impl
(
rms_norm_layer
:
RMSNorm
,
# this stores the weights
rms_norm_layer
:
RMSNorm
,
# this stores the weights
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
):
):
out
,
_
=
ops
.
rms_norm_dynamic_per_token_quant
(
out
,
_
=
ops
.
rms_norm_dynamic_per_token_quant
(
...
...
benchmarks/kernels/bench_block_fp8_gemm.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
# Disable DeepGEMM for this benchmark to use CUTLASS
os
.
environ
[
"VLLM_USE_DEEP_GEMM"
]
=
"0"
import
torch
import
torch
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_w8a8_block_fp8_linear
,
W8A8BlockFp8LinearOp
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
CUTLASS_BLOCK_FP8_SUPPORTED
,
...
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
...
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
# Create random
FP8
tensor
s
# Create random
input
tensor
(bfloat16, will be quantized by W8A8BlockFp8LinearOp)
A_ref
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
A_ref
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
# Create quantized weight tensor
B_ref
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
B_ref
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
B
=
B_ref
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
B
=
B_ref
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
# Create scales
# Create
weight
scales
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
...
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
...
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
*
factor_for_scale
*
factor_for_scale
)
)
# SM90 CUTLASS requires row-major format for scales
# Create W8A8BlockFp8LinearOp instance
if
use_cutlass
and
current_platform
.
is_device_capability
(
90
):
weight_group_shape
=
GroupShape
(
block_n
,
block_k
)
Bs
=
Bs
.
T
.
contiguous
()
act_quant_group_shape
=
GroupShape
(
1
,
block_k
)
# Per-token, per-group quantization
linear_op
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
weight_group_shape
,
act_quant_group_shape
=
act_quant_group_shape
,
cutlass_block_fp8_supported
=
use_cutlass
,
use_aiter_and_is_supported
=
False
,
)
def
run
():
def
run
():
if
use_cutlass
:
return
linear_op
.
apply
(
return
apply_w8a8_block_fp8_linear
(
input
=
A_ref
,
A_ref
,
B
,
block_size
,
Bs
,
cutlass_block_fp8_supported
=
True
weight
=
B
,
)
weight_scale
=
Bs
,
else
:
input_scale
=
None
,
return
apply_w8a8_block_fp8_linear
(
bias
=
None
,
A_ref
,
B
,
block_size
,
Bs
,
cutlass_block_fp8_supported
=
False
)
)
return
run
return
run
...
...
benchmarks/kernels/bench_mxfp4_qutlass.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import
argparse
import
copy
import
itertools
import
torch
from
compressed_tensors.transform.utils.hadamard
import
deterministic_hadamard_matrix
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
fusedQuantizeMx
,
matmul_mxf4_bf16_tn
from
vllm.model_executor.layers.quantization.qutlass_utils
import
to_blocked
from
vllm.triton_utils
import
triton
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"mxfp4"
:
dict
(
no_a_quant
=
False
,
enabled
=
True
),
"mxfp4-noquant"
:
dict
(
no_a_quant
=
True
,
enabled
=
True
),
}
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
def
get_hadamard_matrix
(
group_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
return
(
deterministic_hadamard_matrix
(
group_size
,
dtype
=
dtype
,
device
=
device
)
*
group_size
**-
0.5
)
def
_quant_weight_mxfp4
(
b
:
torch
.
Tensor
,
forward_hadamard_matrix
:
torch
.
Tensor
,
device
:
str
):
weight_hf_e2m1
,
weight_hf_e8m0
=
fusedQuantizeMx
(
b
,
forward_hadamard_matrix
,
method
=
"abs_max"
)
weight_hf_scale_block
=
to_blocked
(
weight_hf_e8m0
,
backend
=
"triton"
)
return
weight_hf_e2m1
,
weight_hf_scale_block
def
build_mxfp4_runner
(
cfg
,
a
,
b
,
forward_hadamard_matrix
,
dtype
,
device
):
weight_hf_e2m1
,
weight_hf_scale_block
=
_quant_weight_mxfp4
(
b
,
forward_hadamard_matrix
,
device
)
alpha
=
torch
.
tensor
([
1.0
],
device
=
"cuda"
)
if
cfg
[
"no_a_quant"
]:
# Pre-quantize activation
input_hf_e2m1
,
input_hf_e8m0
=
fusedQuantizeMx
(
a
,
forward_hadamard_matrix
,
method
=
"abs_max"
)
input_hf_scale_block
=
to_blocked
(
input_hf_e8m0
,
backend
=
"triton"
)
def
run
():
return
matmul_mxf4_bf16_tn
(
input_hf_e2m1
,
weight_hf_e2m1
,
input_hf_scale_block
,
weight_hf_scale_block
,
alpha
,
)
return
run
# Quantize activation on-the-fly
def
run
():
input_hf_e2m1
,
input_hf_e8m0
=
fusedQuantizeMx
(
a
,
forward_hadamard_matrix
,
method
=
"abs_max"
)
input_hf_scale_block
=
to_blocked
(
input_hf_e8m0
,
backend
=
"triton"
)
return
matmul_mxf4_bf16_tn
(
input_hf_e2m1
,
weight_hf_e2m1
,
input_hf_scale_block
,
weight_hf_scale_block
,
alpha
,
)
return
run
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
,
24576
,
32768
,
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
_enabled
,
line_names
=
_enabled
,
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs MXFP4 GEMMs"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
,
N
,
K
,
had_size
):
M
=
batch_size
device
=
"cuda"
dtype
=
torch
.
bfloat16
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
b
=
torch
.
randn
((
N
,
K
),
device
=
device
,
dtype
=
dtype
)
forward_hadamard_matrix
=
get_hadamard_matrix
(
had_size
,
dtype
,
device
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"torch-bf16"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
rep
=
200
,
quantiles
=
quantiles
)
else
:
cfg
=
PROVIDER_CFGS
[
provider
]
run_quant
=
build_mxfp4_runner
(
cfg
,
a
,
b
,
forward_hadamard_matrix
,
dtype
,
device
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
rep
=
200
,
quantiles
=
quantiles
)
to_tflops
=
lambda
t_ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
t_ms
*
1e-3
)
return
to_tflops
(
ms
),
to_tflops
(
max_ms
),
to_tflops
(
min_ms
)
def
prepare_shapes
(
args
):
out
=
[]
for
model
,
tp_size
in
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
):
for
KN
,
tp_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_dim
]
//=
tp_size
KN
.
append
(
model
)
out
.
append
(
KN
)
return
out
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.3-70B-Instruct"
],
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
])
args
=
parser
.
parse_args
()
for
K
,
N
,
model
in
prepare_shapes
(
args
):
for
had_size
in
[
32
,
64
,
128
]:
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, HAD=
{
had_size
}
, BF16 vs MXFP4 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
f
"bench_mxfp4_res_n
{
N
}
_k
{
K
}
"
,
N
=
N
,
K
=
K
,
had_size
=
had_size
,
)
print
(
"Benchmark finished!"
)
benchmarks/kernels/bench_nvfp4_qutlass.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import
argparse
import
copy
import
itertools
import
torch
from
compressed_tensors.transform.utils.hadamard
import
deterministic_hadamard_matrix
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
# use existing nvfp4 gemm in vllm
from
vllm._custom_ops
import
fusedQuantizeNv
from
vllm.model_executor.layers.quantization.qutlass_utils
import
to_blocked
from
vllm.triton_utils
import
triton
PROVIDER_CFGS
=
{
"torch-bf16"
:
dict
(
enabled
=
True
),
"nvfp4"
:
dict
(
no_a_quant
=
False
,
enabled
=
True
),
"nvfp4-noquant"
:
dict
(
no_a_quant
=
True
,
enabled
=
True
),
}
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
def
get_hadamard_matrix
(
group_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
return
(
deterministic_hadamard_matrix
(
group_size
,
dtype
=
dtype
,
device
=
device
)
*
group_size
**-
0.5
)
def
_quant_weight_nvfp4
(
b
:
torch
.
Tensor
,
forward_hadamard_matrix
:
torch
.
Tensor
,
global_scale
:
torch
.
Tensor
,
device
:
str
,
M
:
int
,
N
:
int
,
K
:
int
,
):
weight_hf_e2m1
,
weight_hf_e8m0
=
fusedQuantizeNv
(
b
,
forward_hadamard_matrix
,
global_scale
)
weight_hf_scale_block
=
to_blocked
(
weight_hf_e8m0
,
backend
=
"triton"
).
view
(
-
1
,
K
//
16
)
return
weight_hf_e2m1
,
weight_hf_scale_block
def
build_nvfp4_runner
(
cfg
,
a
,
b
,
forward_hadamard_matrix
,
dtype
,
device
,
M
,
N
,
K
):
alpha
=
torch
.
tensor
([
1.0
],
device
=
"cuda"
)
global_scale
=
torch
.
tensor
([
1.0
],
device
=
"cuda"
)
weight_hf_e2m1
,
weight_hf_scale_block
=
_quant_weight_nvfp4
(
b
,
forward_hadamard_matrix
,
global_scale
,
device
,
M
,
N
,
K
)
if
cfg
[
"no_a_quant"
]:
# Pre-quantize activation
input_hf_e2m1
,
input_hf_e8m0
=
fusedQuantizeNv
(
a
,
forward_hadamard_matrix
,
global_scale
)
input_hf_scale_block
=
to_blocked
(
input_hf_e8m0
,
backend
=
"triton"
).
view
(
-
1
,
K
//
16
)
def
run
():
return
ops
.
cutlass_scaled_fp4_mm
(
input_hf_e2m1
,
weight_hf_e2m1
,
input_hf_scale_block
,
weight_hf_scale_block
,
alpha
,
torch
.
bfloat16
,
)
return
run
# Quantize activation on-the-fly
def
run
():
input_hf_e2m1
,
input_hf_e8m0
=
fusedQuantizeNv
(
a
,
forward_hadamard_matrix
,
global_scale
)
input_hf_scale_block
=
to_blocked
(
input_hf_e8m0
,
backend
=
"triton"
).
view
(
-
1
,
K
//
16
)
return
ops
.
cutlass_scaled_fp4_mm
(
input_hf_e2m1
,
weight_hf_e2m1
,
input_hf_scale_block
,
weight_hf_scale_block
,
alpha
,
torch
.
bfloat16
,
)
return
run
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
,
16384
,
24576
,
32768
,
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
_enabled
,
line_names
=
_enabled
,
ylabel
=
"TFLOP/s (larger is better)"
,
plot_name
=
"BF16 vs NVFP4 GEMMs"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
,
N
,
K
,
had_size
):
M
=
batch_size
device
=
"cuda"
dtype
=
torch
.
bfloat16
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
b
=
torch
.
randn
((
N
,
K
),
device
=
device
,
dtype
=
dtype
)
forward_hadamard_matrix
=
get_hadamard_matrix
(
had_size
,
dtype
,
device
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"torch-bf16"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
torch
.
nn
.
functional
.
linear
(
a
,
b
),
rep
=
200
,
quantiles
=
quantiles
)
else
:
cfg
=
PROVIDER_CFGS
[
provider
]
run_quant
=
build_nvfp4_runner
(
cfg
,
a
,
b
,
forward_hadamard_matrix
,
dtype
,
device
,
M
,
N
,
K
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
run_quant
(),
rep
=
200
,
quantiles
=
quantiles
)
to_tflops
=
lambda
t_ms
:
(
2
*
M
*
N
*
K
)
*
1e-12
/
(
t_ms
*
1e-3
)
return
to_tflops
(
ms
),
to_tflops
(
max_ms
),
to_tflops
(
min_ms
)
def
prepare_shapes
(
args
):
out
=
[]
for
model
,
tp_size
in
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
):
for
KN
,
tp_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model
]):
KN
[
tp_dim
]
//=
tp_size
KN
.
append
(
model
)
out
.
append
(
KN
)
return
out
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.3-70B-Instruct"
],
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
])
args
=
parser
.
parse_args
()
for
K
,
N
,
model
in
prepare_shapes
(
args
):
for
had_size
in
[
16
,
32
,
64
,
128
]:
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, HAD=
{
had_size
}
, BF16 vs NVFP4 GEMMs TFLOP/s:"
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
f
"bench_nvfp4_res_n
{
N
}
_k
{
K
}
"
,
N
=
N
,
K
=
K
,
had_size
=
had_size
,
)
print
(
"Benchmark finished!"
)
benchmarks/kernels/bench_per_token_quant_fp8.py
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
itertools
from
typing
import
Callable
from
collections.abc
import
Callable
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pandas
as
pd
import
pandas
as
pd
...
@@ -10,7 +10,8 @@ import torch
...
@@ -10,7 +10,8 @@ import torch
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
def
with_triton_mode
(
fn
):
def
with_triton_mode
(
fn
):
...
...
benchmarks/kernels/benchmark_activation.py
View file @
006693ed
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
batch_size_range
=
[
1
,
16
,
32
,
64
,
128
]
batch_size_range
=
[
1
,
16
,
32
,
64
,
128
]
seq_len_range
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
seq_len_range
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
...
...
benchmarks/kernels/benchmark_bitblas.py
View file @
006693ed
...
@@ -28,7 +28,7 @@ except ImportError as e:
...
@@ -28,7 +28,7 @@ except ImportError as e:
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark BitBLAS int4 on a specific target."
description
=
"Benchmark BitBLAS int4 on a specific target."
...
...
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
View file @
006693ed
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
WEIGHT_SHAPES_MOE
=
{
WEIGHT_SHAPES_MOE
=
{
"nvidia/DeepSeek-R1-FP4"
:
[
"nvidia/DeepSeek-R1-FP4"
:
[
...
...
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
View file @
006693ed
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# Weight shapes for different models: [num_experts, topk, hidden_size,
# Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size]
# intermediate_size]
...
...
benchmarks/kernels/benchmark_device_communicators.py
View file @
006693ed
...
@@ -22,8 +22,8 @@ Example:
...
@@ -22,8 +22,8 @@ Example:
import
json
import
json
import
os
import
os
import
time
import
time
from
collections.abc
import
Callable
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
typing
import
Callable
,
Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
...
@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
)
)
from
vllm.distributed.device_communicators.symm_mem
import
SymmMemCommunicator
from
vllm.distributed.device_communicators.symm_mem
import
SymmMemCommunicator
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -264,12 +264,12 @@ class CommunicatorBenchmark:
...
@@ -264,12 +264,12 @@ class CommunicatorBenchmark:
def
benchmark_allreduce_single
(
def
benchmark_allreduce_single
(
self
,
self
,
sequence_length
:
int
,
sequence_length
:
int
,
allreduce_fn
:
Callable
[[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]
],
allreduce_fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
|
None
],
should_use_fn
:
Callable
[[
torch
.
Tensor
],
bool
],
should_use_fn
:
Callable
[[
torch
.
Tensor
],
bool
],
context
,
context
,
num_warmup
:
int
,
num_warmup
:
int
,
num_trials
:
int
,
num_trials
:
int
,
)
->
Optional
[
float
]
:
)
->
float
|
None
:
"""Benchmark method with CUDA graph optimization."""
"""Benchmark method with CUDA graph optimization."""
try
:
try
:
# Create test tensor (2D: sequence_length x hidden_size)
# Create test tensor (2D: sequence_length x hidden_size)
...
...
benchmarks/kernels/benchmark_fused_collective.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
"""
import
argparse
import
itertools
import
os
import
time
import
pandas
as
pd
import
torch
# type: ignore
import
torch.distributed
as
dist
# type: ignore
from
vllm.config.vllm
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.distributed
import
(
get_tp_group
,
tensor_model_parallel_all_reduce
,
)
from
vllm.distributed.parallel_state
import
(
graph_capture
,
init_distributed_environment
,
initialize_model_parallel
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
# noqa
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
# noqa
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
# noqa
from
vllm.platforms
import
current_platform
# noqa
RMS_NORM_OP
=
torch
.
ops
.
_C
.
rms_norm
FUSED_ADD_RMS_NORM_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP
=
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP
=
(
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP
=
torch
.
ops
.
_C
.
scaled_fp4_quant
logger
=
init_logger
(
__name__
)
# Try to import FlashInfer
try
:
import
flashinfer.comm
as
flashinfer_comm
# type: ignore
if
not
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
):
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except
ImportError
:
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer not found, only benchmarking standard operations"
)
# Constants
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
MiB
=
1024
*
1024
# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES
=
{
2
:
64
*
MiB
,
# 64MB
4
:
64
*
MiB
,
# 64MB
8
:
64
*
MiB
,
# 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR
=
None
def
setup_flashinfer_workspace
(
world_size
:
int
,
rank
:
int
,
hidden_dim
:
int
,
max_token_num
:
int
,
use_fp32_lamport
:
bool
=
False
,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global
_FI_WORKSPACE_TENSOR
if
flashinfer_comm
is
None
:
return
None
,
None
if
world_size
not
in
_FI_MAX_SIZES
:
logger
.
warning
(
"FlashInfer not supported for world size %s"
,
world_size
)
return
None
,
None
try
:
# Create IPC workspace
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
tp_rank
=
rank
,
tp_size
=
world_size
,
max_token_num
=
max_token_num
,
hidden_dim
=
hidden_dim
,
group
=
get_tp_group
().
device_group
,
use_fp32_lamport
=
use_fp32_lamport
,
)
)
_FI_WORKSPACE_TENSOR
=
workspace_tensor
return
ipc_handles
,
workspace_tensor
except
Exception
as
e
:
logger
.
error
(
"Failed to setup FlashInfer workspace: %s"
,
e
)
return
None
,
None
def
cleanup_flashinfer_workspace
(
ipc_handles
):
"""Cleanup FlashInfer workspace."""
if
flashinfer_comm
is
None
or
ipc_handles
is
None
:
return
try
:
group
=
get_tp_group
().
device_group
flashinfer_comm
.
trtllm_destroy_ipc_workspace_for_all_reduce
(
ipc_handles
,
group
)
except
Exception
as
e
:
logger
.
error
(
"Failed to cleanup FlashInfer workspace: %s"
,
e
)
class
FlashInferFusedAllReduceParams
:
"""Parameters for FlashInfer fused allreduce operations."""
def
__init__
(
self
,
rank
:
int
,
world_size
:
int
,
use_fp32_lamport
:
bool
=
False
,
max_token_num
:
int
=
1024
,
):
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
use_fp32_lamport
=
use_fp32_lamport
self
.
trigger_completion_at_end
=
True
self
.
launch_with_pdl
=
True
self
.
fp32_acc
=
True
self
.
max_token_num
=
max_token_num
def
get_trtllm_fused_allreduce_kwargs
(
self
):
return
{
"world_rank"
:
self
.
rank
,
"world_size"
:
self
.
world_size
,
"launch_with_pdl"
:
self
.
launch_with_pdl
,
"trigger_completion_at_end"
:
self
.
trigger_completion_at_end
,
"fp32_acc"
:
self
.
fp32_acc
,
}
def
flashinfer_fused_allreduce_rmsnorm
(
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
allreduce_params
:
"FlashInferFusedAllReduceParams"
,
use_oneshot
:
bool
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
norm_out
=
input_tensor
residual_out
=
residual
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNorm
,
allreduce_out
=
None
,
quant_out
=
None
,
scale_out
=
None
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
None
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
def
flashinfer_fused_allreduce_rmsnorm_fp8_quant
(
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
scale_factor
:
torch
.
Tensor
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
use_oneshot
:
bool
=
True
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
quant_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
norm_out
=
input_tensor
residual_out
=
residual
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP8Quant
,
allreduce_out
=
None
,
quant_out
=
quant_out
,
scale_out
=
None
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
scale_factor
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
def
flashinfer_fused_allreduce_rmsnorm_fp4_quant
(
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
input_global_scale
:
torch
.
Tensor
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
quant_out
:
torch
.
Tensor
,
use_oneshot
:
bool
,
output_scale
:
torch
.
Tensor
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
norm_out
=
input_tensor
residual_out
=
residual
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP4Quant
,
allreduce_out
=
None
,
quant_out
=
quant_out
,
scale_out
=
output_scale
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
input_global_scale
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
class
VllmFusedAllreduce
:
def
__init__
(
self
,
hidden_dim
,
dtype
):
self
.
rms_eps
=
1e-6
self
.
rms_norm
=
RMSNorm
(
hidden_dim
,
eps
=
self
.
rms_eps
,
dtype
=
dtype
)
self
.
fp8_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
)
def
allreduce_rmsnorm
(
self
,
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
return
self
.
rms_norm
(
allreduce_out
,
residual
)
def
allreduce_rmsnorm_fp8_quant
(
self
,
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
scale_factor
:
torch
.
Tensor
,
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
rms_out
=
self
.
rms_norm
(
allreduce_out
,
residual
)
if
residual
is
None
:
quant_out
=
self
.
fp8_quant
(
rms_out
,
scale_factor
)
return
quant_out
else
:
rms_out
,
residual_out
=
rms_out
quant_out
=
self
.
fp8_quant
(
rms_out
,
scale_factor
)
return
quant_out
,
residual_out
def
allreduce_rmsnorm_fp4_quant
(
self
,
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
input_global_scale
:
torch
.
Tensor
,
quant_out
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
rms_out
=
self
.
rms_norm
(
allreduce_out
,
residual
)
if
residual
is
None
:
SCALED_FP4_QUANT_OP
(
quant_out
,
rms_out
,
output_scale
,
input_global_scale
)
return
quant_out
,
output_scale
else
:
rms_out
,
residual_out
=
rms_out
SCALED_FP4_QUANT_OP
(
quant_out
,
rms_out
,
output_scale
,
input_global_scale
)
return
quant_out
,
residual_out
,
output_scale
def
create_test_tensors
(
num_tokens
:
int
,
hidden_dim
:
int
,
dtype
:
torch
.
dtype
,
use_residual
:
bool
=
True
):
"""Create test tensors for benchmarking."""
input_tensor
=
torch
.
randn
(
num_tokens
,
hidden_dim
,
dtype
=
dtype
)
residual
=
(
torch
.
randn_like
(
input_tensor
)
if
use_residual
else
torch
.
zeros_like
(
input_tensor
)
)
rms_gamma
=
torch
.
ones
(
hidden_dim
,
dtype
=
dtype
)
norm_out
=
None
if
use_residual
else
torch
.
empty_like
(
input_tensor
)
# Quantization scales
scale_fp8
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
scale_fp4
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
quant_out_fp8
=
torch
.
empty_like
(
input_tensor
,
dtype
=
FP8_DTYPE
)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out
=
torch
.
empty
((
num_tokens
,
hidden_dim
//
2
),
dtype
=
torch
.
uint8
)
fp4_output_scale
=
torch
.
empty
((
128
,
4
),
dtype
=
torch
.
int32
)
return
(
input_tensor
,
norm_out
,
residual
,
rms_gamma
,
scale_fp8
,
quant_out_fp8
,
scale_fp4
,
fp4_quant_out
,
fp4_output_scale
,
)
def
benchmark_operation
(
operation_func
,
*
args
,
warmup
:
int
=
5
,
trials
:
int
=
20
,
**
kwargs
):
"""Benchmark a single operation using CUDA graphs."""
# Warmup before graph capture
for
_
in
range
(
warmup
):
operation_func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
# Create CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
num_op_per_cudagraph
=
10
# Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
device
=
torch
.
device
(
f
"cuda:
{
torch
.
cuda
.
current_device
()
}
"
)
with
graph_capture
(
device
=
device
),
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
num_op_per_cudagraph
):
operation_func
(
*
args
,
**
kwargs
)
# Graph warmup
torch
.
cuda
.
synchronize
()
for
_
in
range
(
warmup
):
graph
.
replay
()
# Benchmark with CUDA graph
torch
.
cuda
.
synchronize
()
start_time
=
time
.
perf_counter
()
for
_
in
range
(
trials
//
num_op_per_cudagraph
):
# operation_func(*args, **kwargs)
graph
.
replay
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
perf_counter
()
avg_time_ms
=
((
end_time
-
start_time
)
/
trials
)
*
1000
return
avg_time_ms
def
run_benchmarks
(
num_tokens
:
int
,
hidden_dim
:
int
,
dtype
:
torch
.
dtype
,
use_residual
:
bool
,
allreduce_params
:
FlashInferFusedAllReduceParams
|
None
,
quant_modes
:
set
[
str
],
no_oneshot
:
bool
,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
"""
(
input_tensor
,
norm_out
,
residual
,
rms_gamma
,
scale_fp8
,
quant_out_fp8
,
scale_fp4
,
fp4_quant_out
,
fp4_output_scale
,
)
=
create_test_tensors
(
num_tokens
,
hidden_dim
,
dtype
,
use_residual
)
rms_eps
=
1e-6
results
=
{}
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
use_oneshot_options
=
[
False
]
if
no_oneshot
else
[
True
,
False
]
# Create RMSNorm and QuantFP8 layers once for native benchmarks
if
"none"
in
quant_modes
:
# Standard AllReduce + RMSNorm
for
custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
custom_op
]))
):
try
:
suffix
=
(
"_custom_rms_norm"
if
"+"
in
custom_op
else
"_native_rms_norm"
)
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm
,
input_tensor
,
residual
=
residual
,
)
results
[
f
"standard_allreduce_
{
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm failed: %s"
,
e
)
results
[
f
"standard_allreduce_
{
suffix
}
"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm Native Compiled
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
]))
):
try
:
standard_allreduce_rmsnorm_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm
,
fullgraph
=
True
,
dynamic
=
False
,
)
time_ms
=
benchmark_operation
(
standard_allreduce_rmsnorm_native_compiled
,
input_tensor
,
residual
=
residual
,
)
results
[
"standard_allreduce_rmsnorm_native_compiled"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
allreduce_params
=
allreduce_params
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm
{
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm failed: %s"
,
e
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm
{
suffix
}
"
]
=
float
(
"inf"
)
if
"fp8"
in
quant_modes
:
# Standard AllReduce + RMSNorm + FP8 Quant
for
rms_norm_custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
suffix
=
(
"_custom_rms_norm"
if
"+"
in
rms_norm_custom_op
else
"_native_rms_norm"
)
for
quant_fp8_custom_op
in
[
"-quant_fp8"
,
"+quant_fp8"
]:
suffix
+=
(
"_custom_quant_fp8"
if
"+"
in
quant_fp8_custom_op
else
"_native_quant_fp8"
)
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
rms_norm_custom_op
,
quant_fp8_custom_op
]
)
)
):
try
:
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp8_quant
,
input_tensor
,
residual
=
residual
,
scale_factor
=
scale_fp8
,
)
results
[
f
"standard_allreduce
{
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP8 failed: %s"
,
e
)
results
[
f
"standard_allreduce
{
suffix
}
"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
,
"-quant_fp8"
]
)
)
):
try
:
standard_allreduce_rmsnorm_fp8_quant_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp8_quant
,
fullgraph
=
True
,
dynamic
=
False
,
)
time_ms
=
benchmark_operation
(
standard_allreduce_rmsnorm_fp8_quant_native_compiled
,
input_tensor
,
residual
=
residual
,
scale_factor
=
scale_fp8
,
)
results
[
"standard_allreduce_rmsnorm_fp8_quant_native_compiled"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_fp8_quant_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp8_quant
,
input_tensor
,
norm_out
=
norm_out
,
residual
=
residual
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
scale_factor
=
scale_fp8
,
quant_out
=
quant_out_fp8
,
allreduce_params
=
allreduce_params
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s"
,
e
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
]
=
(
float
(
"inf"
)
)
if
"fp4"
in
quant_modes
and
current_platform
.
has_device_capability
(
100
):
# Standard AllReduce + RMSNorm + FP4 Quant
for
rms_norm_custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
suffix
=
(
"_custom_rms_norm"
if
"+"
in
rms_norm_custom_op
else
"_native_rms_norm"
)
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
rms_norm_custom_op
]
)
)
):
try
:
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
input_global_scale
=
scale_fp4
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
)
results
[
f
"standard_allreduce_
{
suffix
}
_fp4_quant"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP4 failed: %s"
,
e
)
results
[
f
"standard_allreduce_
{
suffix
}
_fp4_quant"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
]))
):
try
:
standard_allreduce_rmsnorm_fp4_quant_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp4_quant
,
fullgraph
=
True
,
dynamic
=
False
,
)
time_ms
=
benchmark_operation
(
standard_allreduce_rmsnorm_fp4_quant_native_compiled
,
input_tensor
,
residual
=
residual
,
quant_out
=
fp4_quant_out
,
input_global_scale
=
scale_fp4
,
output_scale
=
fp4_output_scale
,
)
results
[
"standard_allreduce_rmsnorm_fp4_quant_native_compiled"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_fp4_quant_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
input_global_scale
=
scale_fp4
,
allreduce_params
=
allreduce_params
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s"
,
e
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
]
=
(
float
(
"inf"
)
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
input_global_scale
=
scale_fp4
,
allreduce_params
=
allreduce_params
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
use_oneshot
=
False
,
)
results
[
"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s"
,
e
,
)
results
[
"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"
]
=
float
(
"inf"
)
return
results
def
prepare_results_with_speedups
(
results_dict
):
"""Prepare results with speedup calculations based on dynamic baseline selection."""
prepared_results
=
[]
# Determine the fastest baseline for each operation type
def
get_fastest_baseline
(
op_name
,
results_dict
):
"""Get the fastest baseline between standard and native_compiled versions."""
if
"fp8_quant"
in
op_name
:
candidates
=
[
"standard_allreduce_rmsnorm_fp8_quant"
,
"standard_allreduce_rmsnorm_fp8_quant_native_compiled"
,
]
elif
"fp4_quant"
in
op_name
:
candidates
=
[
"standard_allreduce_rmsnorm_fp4_quant"
,
"standard_allreduce_rmsnorm_fp4_quant_native_compiled"
,
]
else
:
candidates
=
[
"standard_allreduce_rmsnorm"
,
"standard_allreduce_rmsnorm_native_compiled"
,
]
# Find the fastest among available candidates
fastest_time
=
float
(
"inf"
)
fastest_baseline
=
None
for
candidate
in
candidates
:
if
(
candidate
in
results_dict
and
results_dict
[
candidate
]
!=
float
(
"inf"
)
and
results_dict
[
candidate
]
<
fastest_time
):
fastest_time
=
results_dict
[
candidate
]
fastest_baseline
=
candidate
return
fastest_baseline
# Create dynamic baseline mapping
dynamic_baseline_mapping
=
{}
for
op_name
in
results_dict
:
if
(
op_name
.
startswith
(
"flashinfer_"
)
or
op_name
.
startswith
(
"standard_"
)
and
not
op_name
.
endswith
(
"_native_compiled"
)
):
dynamic_baseline_mapping
[
op_name
]
=
get_fastest_baseline
(
op_name
,
results_dict
)
for
op_name
,
time_ms
in
results_dict
.
items
():
if
time_ms
==
float
(
"inf"
):
speedup_str
=
"FAILED"
time_str
=
"FAILED"
else
:
time_str
=
f
"
{
time_ms
:.
3
f
}
"
# Find the appropriate baseline for this operation
baseline_op
=
dynamic_baseline_mapping
.
get
(
op_name
)
if
baseline_op
and
baseline_op
in
results_dict
:
baseline_time
=
results_dict
[
baseline_op
]
if
baseline_time
!=
float
(
"inf"
)
and
baseline_time
>
0
:
speedup
=
baseline_time
/
time_ms
speedup_str
=
f
"
{
speedup
:.
2
f
}
x"
else
:
speedup_str
=
"N/A"
else
:
# For baseline operations, determine if this is the fastest baseline
if
op_name
.
endswith
(
"_native_compiled"
)
or
(
op_name
.
startswith
(
"standard_"
)
and
not
op_name
.
endswith
(
"_native_compiled"
)
):
fastest_baseline
=
get_fastest_baseline
(
op_name
,
results_dict
)
if
fastest_baseline
==
op_name
:
speedup_str
=
"baseline"
else
:
if
fastest_baseline
and
fastest_baseline
in
results_dict
:
baseline_time
=
results_dict
[
fastest_baseline
]
if
baseline_time
!=
float
(
"inf"
)
and
baseline_time
>
0
:
speedup
=
baseline_time
/
time_ms
speedup_str
=
f
"
{
speedup
:.
2
f
}
x"
else
:
speedup_str
=
"N/A"
else
:
speedup_str
=
"N/A"
else
:
speedup_str
=
"N/A"
prepared_results
.
append
(
{
"operation"
:
op_name
,
"time_ms"
:
time_ms
,
"time_str"
:
time_str
,
"speedup_str"
:
speedup_str
,
}
)
return
prepared_results
def
print_results
(
results_dict
,
num_tokens
,
hidden_dim
,
dtype
,
use_residual
,
quant_modes
,
input_size_mb
,
):
"""Print benchmark results in a formatted table."""
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
f
"Results: num_tokens=
{
num_tokens
}
, hidden_dim=
{
hidden_dim
}
"
f
"(input size:
{
input_size_mb
:.
2
f
}
MB)"
)
print
(
f
"dtype=
{
dtype
}
, residual=
{
'yes'
if
use_residual
else
'no'
}
, "
f
"quant_modes=
{
','
.
join
(
sorted
(
list
(
quant_modes
)))
}
"
)
print
(
f
"
{
'='
*
80
}
"
)
print
(
f
"
{
'Operation'
:
<
50
}
{
'Time (ms)'
:
<
12
}
{
'Speedup'
:
<
10
}
"
)
print
(
f
"
{
'-'
*
80
}
"
)
# Prepare results with speedup calculations
prepared_results
=
prepare_results_with_speedups
(
results_dict
)
for
result
in
prepared_results
:
if
result
[
"time_ms"
]
==
float
(
"inf"
):
time_display
=
result
[
"time_str"
]
else
:
time_display
=
f
"
{
result
[
'time_ms'
]:.
3
f
}
"
print
(
f
"
{
result
[
'operation'
]:
<
50
}
{
time_display
:
<
12
}
{
result
[
'speedup_str'
]:
<
10
}
"
)
def
format_results_markdown
(
all_results
:
list
[
dict
],
world_size
:
int
,
args
:
argparse
.
Namespace
)
->
str
:
"""Format all benchmark results as markdown."""
lines
:
list
[
str
]
=
[]
lines
.
append
(
"# FlashInfer Fused Collective Operations Benchmark Results"
)
lines
.
append
(
""
)
lines
.
append
(
f
"**World Size:**
{
world_size
}
"
)
lines
.
append
(
f
"**Hidden Dimension:**
{
args
.
hidden_dim
}
"
)
lines
.
append
(
f
"**Warmup Iterations:**
{
args
.
warmup
}
"
)
lines
.
append
(
f
"**Benchmark Trials:**
{
args
.
trials
}
"
)
modes
=
","
.
join
(
all_results
[
0
][
"quant_modes"
])
if
all_results
else
"N/A"
lines
.
append
(
f
"**Quantization Modes:**
{
modes
}
"
)
lines
.
append
(
""
)
lines
.
append
(
"---"
)
lines
.
append
(
""
)
for
entry
in
all_results
:
num_tokens
=
entry
[
"num_tokens"
]
dtype
=
entry
[
"dtype"
]
use_residual
=
entry
[
"use_residual"
]
results_dict
=
entry
[
"results"
]
input_size_mb
=
entry
[
"input_size_mb"
]
residual_str
=
"with residual"
if
use_residual
else
"no residual"
lines
.
append
(
f
"## Configuration: num_tokens=
{
num_tokens
}
, dtype=
{
dtype
}
,
{
residual_str
}
"
)
lines
.
append
(
f
"**Input Size:**
{
input_size_mb
:.
2
f
}
MB"
)
lines
.
append
(
""
)
prepared
=
prepare_results_with_speedups
(
results_dict
)
# Build DataFrame for markdown export
rows
=
[
{
"Operation"
:
r
[
"operation"
].
replace
(
"_"
,
" "
).
title
(),
"Time (ms)"
:
r
[
"time_str"
],
"Speedup"
:
r
[
"speedup_str"
],
}
for
r
in
prepared
]
df
=
pd
.
DataFrame
(
rows
)
if
df
.
empty
:
lines
.
append
(
"No results."
)
else
:
lines
.
append
(
df
.
to_markdown
(
index
=
False
))
lines
.
append
(
""
)
return
"
\n
"
.
join
(
lines
)
def
save_results_to_file
(
all_results
:
list
[
dict
],
world_size
:
int
,
args
:
argparse
.
Namespace
,
rank
:
int
):
"""Save benchmark results to markdown file (only on rank 0)."""
if
rank
!=
0
:
return
if
not
all_results
:
logger
.
warning
(
"No results to save"
)
return
output_path
=
args
.
output_file
try
:
markdown_content
=
format_results_markdown
(
all_results
,
world_size
,
args
)
with
open
(
output_path
,
"a"
)
as
f
:
f
.
write
(
markdown_content
)
except
Exception
as
e
:
logger
.
error
(
"Failed to save results to file: %s"
,
e
)
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Benchmark fused collective operations"
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
nargs
=
"+"
,
default
=
[
128
,
512
,
1024
,
2048
],
help
=
"Numbers of tokens to test"
,
)
parser
.
add_argument
(
"--hidden-dim"
,
type
=
int
,
default
=
8192
,
help
=
"Hidden dimension size"
)
parser
.
add_argument
(
"--dtypes"
,
type
=
str
,
nargs
=
"+"
,
default
=
[
"bfloat16"
],
choices
=
[
"float16"
,
"bfloat16"
,
"float32"
],
help
=
"Data types to test"
,
)
parser
.
add_argument
(
"--no-residual"
,
action
=
"store_true"
,
help
=
"Skip residual connection tests"
,
)
parser
.
add_argument
(
"--quant-modes"
,
type
=
str
,
default
=
"none,fp8,fp4"
,
help
=
(
"Comma-separated quantization modes to run: none, fp8, fp4. "
"Default: none,fp8,fp4"
),
)
parser
.
add_argument
(
"--warmup"
,
type
=
int
,
default
=
5
,
help
=
"Number of warmup iterations"
)
parser
.
add_argument
(
"--trials"
,
type
=
int
,
default
=
20
,
help
=
"Number of benchmark trials"
)
parser
.
add_argument
(
"--output-file"
,
type
=
str
,
help
=
"""Output file path for markdown results
(default: benchmark_results_<timestamp>.md)
"""
,
)
parser
.
add_argument
(
"--no-oneshot"
,
action
=
"store_true"
,
help
=
"Skip oneshot benchmarks"
,
)
args
=
parser
.
parse_args
()
# Check if running with torchrun (required for collective operations)
if
"RANK"
not
in
os
.
environ
or
"WORLD_SIZE"
not
in
os
.
environ
:
raise
RuntimeError
(
"Must run with torchrun for distributed benchmarking. "
"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
)
# Initialize distributed environment
rank
=
int
(
os
.
environ
[
"RANK"
])
world_size
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Validate world size (must be > 1 for collective operations)
if
world_size
<=
1
:
raise
ValueError
(
"World size must be > 1 for collective operations benchmarking. "
f
"Current world size:
{
world_size
}
. Use torchrun with --nproc_per_node > 1."
)
# Parse quantization modes
valid_quant_modes
=
{
"none"
,
"fp8"
,
"fp4"
}
raw_modes
=
[
m
.
strip
().
lower
()
for
m
in
(
args
.
quant_modes
or
""
).
split
(
","
)
if
m
.
strip
()
]
quant_modes
=
set
(
raw_modes
)
if
raw_modes
else
{
"none"
,
"fp8"
,
"fp4"
}
invalid
=
sorted
(
list
(
quant_modes
-
valid_quant_modes
))
if
invalid
:
raise
ValueError
(
f
"Invalid --quant-modes entries:
{
','
.
join
(
invalid
)
}
. "
f
"Valid options are:
{
','
.
join
(
sorted
(
valid_quant_modes
))
}
."
)
if
rank
==
0
:
logger
.
info
(
"Running benchmark with world_size=%s, rank=%s"
,
world_size
,
rank
)
logger
.
info
(
"Quantization modes: %s"
,
","
.
join
(
sorted
(
list
(
quant_modes
))))
if
flashinfer_comm
is
not
None
:
logger
.
info
(
"FlashInfer available - will benchmark fused operations"
,
)
else
:
logger
.
info
(
"FlashInfer not available - only benchmarking standard operations"
)
# Convert dtype strings to torch dtypes
dtype_map
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
,
"float32"
:
torch
.
float32
,
}
dtypes
=
[
dtype_map
[
dt
]
for
dt
in
args
.
dtypes
]
# Test configurations
residual_options
=
[
True
]
if
not
args
.
no_residual
else
[
False
]
configs
=
list
(
itertools
.
product
(
args
.
num_tokens
,
dtypes
,
residual_options
))
# Setup FlashInfer workspace if available
ipc_handles
=
None
allreduce_params
=
None
if
flashinfer_comm
is
not
None
:
# Use the largest hidden dimension for workspace setup
max_num_token
=
_FI_MAX_SIZES
.
get
(
world_size
)
//
(
args
.
hidden_dim
*
world_size
*
2
)
ipc_handles
,
workspace_tensor
=
setup_flashinfer_workspace
(
world_size
,
rank
,
args
.
hidden_dim
,
max_num_token
)
if
workspace_tensor
is
not
None
:
allreduce_params
=
FlashInferFusedAllReduceParams
(
rank
=
rank
,
world_size
=
world_size
,
max_token_num
=
max_num_token
,
)
# Collect all results for markdown export
all_results
=
[]
try
:
# Run benchmarks
for
num_tokens
,
dtype
,
use_residual
in
configs
:
if
rank
==
0
:
logger
.
info
(
"
\n
Testing: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s"
,
num_tokens
,
args
.
hidden_dim
,
dtype
,
use_residual
,
)
results
=
run_benchmarks
(
num_tokens
,
args
.
hidden_dim
,
dtype
,
use_residual
,
allreduce_params
,
quant_modes
=
quant_modes
,
no_oneshot
=
args
.
no_oneshot
,
)
# Store results for markdown export
if
rank
==
0
:
# Calculate input size in MB
input_size_mb
=
(
num_tokens
*
args
.
hidden_dim
*
torch
.
finfo
(
dtype
).
bits
)
/
(
8
*
1024
*
1024
)
all_results
.
append
(
{
"num_tokens"
:
num_tokens
,
"hidden_dim"
:
args
.
hidden_dim
,
"dtype"
:
str
(
dtype
).
replace
(
"torch."
,
""
),
"use_residual"
:
use_residual
,
"quant_modes"
:
sorted
(
list
(
quant_modes
)),
"input_size_mb"
:
input_size_mb
,
"results"
:
results
,
}
)
print_results
(
results
,
num_tokens
,
args
.
hidden_dim
,
dtype
,
use_residual
,
quant_modes
,
input_size_mb
,
)
# Save results to markdown file
if
args
.
output_file
and
rank
==
0
:
save_results_to_file
(
all_results
,
world_size
,
args
,
rank
)
finally
:
# Cleanup
if
ipc_handles
is
not
None
:
cleanup_flashinfer_workspace
(
ipc_handles
)
dist
.
barrier
()
if
__name__
==
"__main__"
:
main
()
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
006693ed
...
@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
...
@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts
,
fused_experts
,
fused_topk
,
fused_topk
,
)
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
DEFAULT_MODELS
=
[
"
nm-testing
/Mixtral-8x7B-Instruct-v0.1"
,
"
mistralai
/Mixtral-8x7B-Instruct-v0.1"
,
"
nm-testing/d
eep
s
eek
v
2-
l
ite"
,
"
deepseek-ai/D
eep
S
eek
-V
2-
L
ite"
,
"ibm-granite/granite-3.0-1b-a400m"
,
"ibm-granite/granite-3.0-1b-a400m"
,
"ibm-granite/granite-3.0-3b-a800m"
,
"ibm-granite/granite-3.0-3b-a800m"
,
]
]
...
...
Prev
1
2
3
4
5
6
7
8
9
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment