Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41199996
Commit
41199996
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.12.0' into v0.12.0-dev
parents
31021d81
4fd9d6a8
Changes
380
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1847 additions
and
270 deletions
+1847
-270
benchmarks/kernels/bench_per_token_quant_fp8.py
benchmarks/kernels/bench_per_token_quant_fp8.py
+3
-2
benchmarks/kernels/benchmark_activation.py
benchmarks/kernels/benchmark_activation.py
+2
-1
benchmarks/kernels/benchmark_bitblas.py
benchmarks/kernels/benchmark_bitblas.py
+1
-1
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
+1
-1
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+3
-3
benchmarks/kernels/benchmark_device_communicators.py
benchmarks/kernels/benchmark_device_communicators.py
+4
-4
benchmarks/kernels/benchmark_fused_collective.py
benchmarks/kernels/benchmark_fused_collective.py
+1129
-0
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+3
-3
benchmarks/kernels/benchmark_layernorm.py
benchmarks/kernels/benchmark_layernorm.py
+2
-1
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+457
-40
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+17
-17
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+3
-3
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+27
-10
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+6
-6
benchmarks/kernels/benchmark_mrope.py
benchmarks/kernels/benchmark_mrope.py
+8
-13
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+4
-6
benchmarks/kernels/benchmark_per_token_group_quant.py
benchmarks/kernels/benchmark_per_token_group_quant.py
+3
-3
benchmarks/kernels/benchmark_polynorm.py
benchmarks/kernels/benchmark_polynorm.py
+0
-155
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+2
-1
benchmarks/kernels/benchmark_reshape_and_cache.py
benchmarks/kernels/benchmark_reshape_and_cache.py
+172
-0
No files found.
Too many changes to show.
To preserve performance only
380 of 380+
files are displayed.
Plain diff
Email patch
benchmarks/kernels/bench_per_token_quant_fp8.py
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
from
typing
import
Callable
from
collections.abc
import
Callable
from
unittest.mock
import
patch
import
pandas
as
pd
...
...
@@ -10,7 +10,8 @@ import torch
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.triton_utils
import
triton
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
def
with_triton_mode
(
fn
):
...
...
benchmarks/kernels/benchmark_activation.py
View file @
41199996
...
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
batch_size_range
=
[
1
,
16
,
32
,
64
,
128
]
seq_len_range
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
...
...
benchmarks/kernels/benchmark_bitblas.py
View file @
41199996
...
...
@@ -28,7 +28,7 @@ except ImportError as e:
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark BitBLAS int4 on a specific target."
...
...
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
View file @
41199996
...
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
WEIGHT_SHAPES_MOE
=
{
"nvidia/DeepSeek-R1-FP4"
:
[
...
...
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
View file @
41199996
...
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size]
...
...
@@ -255,8 +255,8 @@ def bench_run(
torch
.
cuda
.
synchronize
()
# Timing
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
=
[]
for
_
in
range
(
num_iters
):
...
...
benchmarks/kernels/benchmark_device_communicators.py
View file @
41199996
...
...
@@ -22,8 +22,8 @@ Example:
import
json
import
os
import
time
from
collections.abc
import
Callable
from
contextlib
import
nullcontext
from
typing
import
Callable
,
Optional
import
torch
import
torch.distributed
as
dist
...
...
@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
)
from
vllm.distributed.device_communicators.symm_mem
import
SymmMemCommunicator
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
logger
=
init_logger
(
__name__
)
...
...
@@ -264,12 +264,12 @@ class CommunicatorBenchmark:
def
benchmark_allreduce_single
(
self
,
sequence_length
:
int
,
allreduce_fn
:
Callable
[[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]
],
allreduce_fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
|
None
],
should_use_fn
:
Callable
[[
torch
.
Tensor
],
bool
],
context
,
num_warmup
:
int
,
num_trials
:
int
,
)
->
Optional
[
float
]
:
)
->
float
|
None
:
"""Benchmark method with CUDA graph optimization."""
try
:
# Create test tensor (2D: sequence_length x hidden_size)
...
...
benchmarks/kernels/benchmark_fused_collective.py
0 → 100644
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
"""
import
argparse
import
itertools
import
os
import
time
import
pandas
as
pd
import
torch
# type: ignore
import
torch.distributed
as
dist
# type: ignore
from
vllm.config.vllm
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.distributed
import
(
get_tp_group
,
tensor_model_parallel_all_reduce
,
)
from
vllm.distributed.parallel_state
import
(
graph_capture
,
init_distributed_environment
,
initialize_model_parallel
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
# noqa
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
# noqa
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
# noqa
from
vllm.platforms
import
current_platform
# noqa
RMS_NORM_OP
=
torch
.
ops
.
_C
.
rms_norm
FUSED_ADD_RMS_NORM_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP
=
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP
=
(
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP
=
torch
.
ops
.
_C
.
scaled_fp4_quant
logger
=
init_logger
(
__name__
)
# Try to import FlashInfer
try
:
import
flashinfer.comm
as
flashinfer_comm
# type: ignore
if
not
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
):
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except
ImportError
:
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer not found, only benchmarking standard operations"
)
# Constants
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
MiB
=
1024
*
1024
# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES
=
{
2
:
64
*
MiB
,
# 64MB
4
:
64
*
MiB
,
# 64MB
8
:
64
*
MiB
,
# 64MB
}
# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR
=
None
def
setup_flashinfer_workspace
(
world_size
:
int
,
rank
:
int
,
hidden_dim
:
int
,
max_token_num
:
int
,
use_fp32_lamport
:
bool
=
False
,
):
"""Setup FlashInfer workspace for fused allreduce operations."""
global
_FI_WORKSPACE_TENSOR
if
flashinfer_comm
is
None
:
return
None
,
None
if
world_size
not
in
_FI_MAX_SIZES
:
logger
.
warning
(
"FlashInfer not supported for world size %s"
,
world_size
)
return
None
,
None
try
:
# Create IPC workspace
ipc_handles
,
workspace_tensor
=
(
flashinfer_comm
.
trtllm_create_ipc_workspace_for_all_reduce_fusion
(
tp_rank
=
rank
,
tp_size
=
world_size
,
max_token_num
=
max_token_num
,
hidden_dim
=
hidden_dim
,
group
=
get_tp_group
().
device_group
,
use_fp32_lamport
=
use_fp32_lamport
,
)
)
_FI_WORKSPACE_TENSOR
=
workspace_tensor
return
ipc_handles
,
workspace_tensor
except
Exception
as
e
:
logger
.
error
(
"Failed to setup FlashInfer workspace: %s"
,
e
)
return
None
,
None
def
cleanup_flashinfer_workspace
(
ipc_handles
):
"""Cleanup FlashInfer workspace."""
if
flashinfer_comm
is
None
or
ipc_handles
is
None
:
return
try
:
group
=
get_tp_group
().
device_group
flashinfer_comm
.
trtllm_destroy_ipc_workspace_for_all_reduce
(
ipc_handles
,
group
)
except
Exception
as
e
:
logger
.
error
(
"Failed to cleanup FlashInfer workspace: %s"
,
e
)
class
FlashInferFusedAllReduceParams
:
"""Parameters for FlashInfer fused allreduce operations."""
def
__init__
(
self
,
rank
:
int
,
world_size
:
int
,
use_fp32_lamport
:
bool
=
False
,
max_token_num
:
int
=
1024
,
):
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
use_fp32_lamport
=
use_fp32_lamport
self
.
trigger_completion_at_end
=
True
self
.
launch_with_pdl
=
True
self
.
fp32_acc
=
True
self
.
max_token_num
=
max_token_num
def
get_trtllm_fused_allreduce_kwargs
(
self
):
return
{
"world_rank"
:
self
.
rank
,
"world_size"
:
self
.
world_size
,
"launch_with_pdl"
:
self
.
launch_with_pdl
,
"trigger_completion_at_end"
:
self
.
trigger_completion_at_end
,
"fp32_acc"
:
self
.
fp32_acc
,
}
def
flashinfer_fused_allreduce_rmsnorm
(
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
allreduce_params
:
"FlashInferFusedAllReduceParams"
,
use_oneshot
:
bool
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm operation."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
norm_out
=
input_tensor
residual_out
=
residual
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNorm
,
allreduce_out
=
None
,
quant_out
=
None
,
scale_out
=
None
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
None
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
def
flashinfer_fused_allreduce_rmsnorm_fp8_quant
(
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
scale_factor
:
torch
.
Tensor
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
use_oneshot
:
bool
=
True
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
quant_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
norm_out
=
input_tensor
residual_out
=
residual
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP8Quant
,
allreduce_out
=
None
,
quant_out
=
quant_out
,
scale_out
=
None
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
scale_factor
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
def
flashinfer_fused_allreduce_rmsnorm_fp4_quant
(
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
rms_gamma
:
torch
.
Tensor
,
rms_eps
:
float
,
input_global_scale
:
torch
.
Tensor
,
allreduce_params
:
FlashInferFusedAllReduceParams
,
quant_out
:
torch
.
Tensor
,
use_oneshot
:
bool
,
output_scale
:
torch
.
Tensor
,
norm_out
:
torch
.
Tensor
|
None
=
None
,
):
"""FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
if
flashinfer_comm
is
None
or
_FI_WORKSPACE_TENSOR
is
None
:
raise
RuntimeError
(
"FlashInfer not available or workspace not initialized"
)
if
norm_out
is
None
:
norm_out
=
input_tensor
residual_out
=
residual
else
:
residual_out
=
input_tensor
flashinfer_comm
.
trtllm_allreduce_fusion
(
allreduce_in
=
input_tensor
,
token_num
=
input_tensor
.
shape
[
0
],
residual_in
=
residual
,
residual_out
=
residual_out
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
hidden_dim
=
input_tensor
.
shape
[
-
1
],
workspace_ptrs
=
_FI_WORKSPACE_TENSOR
,
pattern_code
=
flashinfer_comm
.
AllReduceFusionPattern
.
kARResidualRMSNormFP4Quant
,
allreduce_out
=
None
,
quant_out
=
quant_out
,
scale_out
=
output_scale
,
layout_code
=
flashinfer_comm
.
QuantizationSFLayout
.
SWIZZLED_128x4
,
scale_factor
=
input_global_scale
,
use_oneshot
=
use_oneshot
,
**
allreduce_params
.
get_trtllm_fused_allreduce_kwargs
(),
)
class
VllmFusedAllreduce
:
def
__init__
(
self
,
hidden_dim
,
dtype
):
self
.
rms_eps
=
1e-6
self
.
rms_norm
=
RMSNorm
(
hidden_dim
,
eps
=
self
.
rms_eps
,
dtype
=
dtype
)
self
.
fp8_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
,
)
def
allreduce_rmsnorm
(
self
,
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
return
self
.
rms_norm
(
allreduce_out
,
residual
)
def
allreduce_rmsnorm_fp8_quant
(
self
,
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
scale_factor
:
torch
.
Tensor
,
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
rms_out
=
self
.
rms_norm
(
allreduce_out
,
residual
)
if
residual
is
None
:
quant_out
=
self
.
fp8_quant
(
rms_out
,
scale_factor
)
return
quant_out
else
:
rms_out
,
residual_out
=
rms_out
quant_out
=
self
.
fp8_quant
(
rms_out
,
scale_factor
)
return
quant_out
,
residual_out
def
allreduce_rmsnorm_fp4_quant
(
self
,
input_tensor
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
input_global_scale
:
torch
.
Tensor
,
quant_out
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
rms_out
=
self
.
rms_norm
(
allreduce_out
,
residual
)
if
residual
is
None
:
SCALED_FP4_QUANT_OP
(
quant_out
,
rms_out
,
output_scale
,
input_global_scale
)
return
quant_out
,
output_scale
else
:
rms_out
,
residual_out
=
rms_out
SCALED_FP4_QUANT_OP
(
quant_out
,
rms_out
,
output_scale
,
input_global_scale
)
return
quant_out
,
residual_out
,
output_scale
def
create_test_tensors
(
num_tokens
:
int
,
hidden_dim
:
int
,
dtype
:
torch
.
dtype
,
use_residual
:
bool
=
True
):
"""Create test tensors for benchmarking."""
input_tensor
=
torch
.
randn
(
num_tokens
,
hidden_dim
,
dtype
=
dtype
)
residual
=
(
torch
.
randn_like
(
input_tensor
)
if
use_residual
else
torch
.
zeros_like
(
input_tensor
)
)
rms_gamma
=
torch
.
ones
(
hidden_dim
,
dtype
=
dtype
)
norm_out
=
None
if
use_residual
else
torch
.
empty_like
(
input_tensor
)
# Quantization scales
scale_fp8
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
scale_fp4
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
quant_out_fp8
=
torch
.
empty_like
(
input_tensor
,
dtype
=
FP8_DTYPE
)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out
=
torch
.
empty
((
num_tokens
,
hidden_dim
//
2
),
dtype
=
torch
.
uint8
)
fp4_output_scale
=
torch
.
empty
((
128
,
4
),
dtype
=
torch
.
int32
)
return
(
input_tensor
,
norm_out
,
residual
,
rms_gamma
,
scale_fp8
,
quant_out_fp8
,
scale_fp4
,
fp4_quant_out
,
fp4_output_scale
,
)
def
benchmark_operation
(
operation_func
,
*
args
,
warmup
:
int
=
5
,
trials
:
int
=
20
,
**
kwargs
):
"""Benchmark a single operation using CUDA graphs."""
# Warmup before graph capture
for
_
in
range
(
warmup
):
operation_func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
# Create CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
num_op_per_cudagraph
=
10
# Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
device
=
torch
.
device
(
f
"cuda:
{
torch
.
cuda
.
current_device
()
}
"
)
with
graph_capture
(
device
=
device
),
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
num_op_per_cudagraph
):
operation_func
(
*
args
,
**
kwargs
)
# Graph warmup
torch
.
cuda
.
synchronize
()
for
_
in
range
(
warmup
):
graph
.
replay
()
# Benchmark with CUDA graph
torch
.
cuda
.
synchronize
()
start_time
=
time
.
perf_counter
()
for
_
in
range
(
trials
//
num_op_per_cudagraph
):
# operation_func(*args, **kwargs)
graph
.
replay
()
torch
.
cuda
.
synchronize
()
end_time
=
time
.
perf_counter
()
avg_time_ms
=
((
end_time
-
start_time
)
/
trials
)
*
1000
return
avg_time_ms
def
run_benchmarks
(
num_tokens
:
int
,
hidden_dim
:
int
,
dtype
:
torch
.
dtype
,
use_residual
:
bool
,
allreduce_params
:
FlashInferFusedAllReduceParams
|
None
,
quant_modes
:
set
[
str
],
no_oneshot
:
bool
,
):
"""Run all benchmarks for given configuration.
Args:
quant_mode: "none", "fp8_only", "fp4_only", or "all"
"""
(
input_tensor
,
norm_out
,
residual
,
rms_gamma
,
scale_fp8
,
quant_out_fp8
,
scale_fp4
,
fp4_quant_out
,
fp4_output_scale
,
)
=
create_test_tensors
(
num_tokens
,
hidden_dim
,
dtype
,
use_residual
)
rms_eps
=
1e-6
results
=
{}
vllm_fused_allreduce
=
VllmFusedAllreduce
(
hidden_dim
,
dtype
)
use_oneshot_options
=
[
False
]
if
no_oneshot
else
[
True
,
False
]
# Create RMSNorm and QuantFP8 layers once for native benchmarks
if
"none"
in
quant_modes
:
# Standard AllReduce + RMSNorm
for
custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
custom_op
]))
):
try
:
suffix
=
(
"_custom_rms_norm"
if
"+"
in
custom_op
else
"_native_rms_norm"
)
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm
,
input_tensor
,
residual
=
residual
,
)
results
[
f
"standard_allreduce_
{
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm failed: %s"
,
e
)
results
[
f
"standard_allreduce_
{
suffix
}
"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm Native Compiled
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
]))
):
try
:
standard_allreduce_rmsnorm_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm
,
fullgraph
=
True
,
dynamic
=
False
,
)
time_ms
=
benchmark_operation
(
standard_allreduce_rmsnorm_native_compiled
,
input_tensor
,
residual
=
residual
,
)
results
[
"standard_allreduce_rmsnorm_native_compiled"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
allreduce_params
=
allreduce_params
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm
{
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm failed: %s"
,
e
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm
{
suffix
}
"
]
=
float
(
"inf"
)
if
"fp8"
in
quant_modes
:
# Standard AllReduce + RMSNorm + FP8 Quant
for
rms_norm_custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
suffix
=
(
"_custom_rms_norm"
if
"+"
in
rms_norm_custom_op
else
"_native_rms_norm"
)
for
quant_fp8_custom_op
in
[
"-quant_fp8"
,
"+quant_fp8"
]:
suffix
+=
(
"_custom_quant_fp8"
if
"+"
in
quant_fp8_custom_op
else
"_native_quant_fp8"
)
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
rms_norm_custom_op
,
quant_fp8_custom_op
]
)
)
):
try
:
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp8_quant
,
input_tensor
,
residual
=
residual
,
scale_factor
=
scale_fp8
,
)
results
[
f
"standard_allreduce
{
suffix
}
"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP8 failed: %s"
,
e
)
results
[
f
"standard_allreduce
{
suffix
}
"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
,
"-quant_fp8"
]
)
)
):
try
:
standard_allreduce_rmsnorm_fp8_quant_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp8_quant
,
fullgraph
=
True
,
dynamic
=
False
,
)
time_ms
=
benchmark_operation
(
standard_allreduce_rmsnorm_fp8_quant_native_compiled
,
input_tensor
,
residual
=
residual
,
scale_factor
=
scale_fp8
,
)
results
[
"standard_allreduce_rmsnorm_fp8_quant_native_compiled"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_fp8_quant_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp8_quant
,
input_tensor
,
norm_out
=
norm_out
,
residual
=
residual
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
scale_factor
=
scale_fp8
,
quant_out
=
quant_out_fp8
,
allreduce_params
=
allreduce_params
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s"
,
e
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp8_quant
{
suffix
}
"
]
=
(
float
(
"inf"
)
)
if
"fp4"
in
quant_modes
and
current_platform
.
has_device_capability
(
100
):
# Standard AllReduce + RMSNorm + FP4 Quant
for
rms_norm_custom_op
in
[
"-rms_norm"
,
"+rms_norm"
]:
suffix
=
(
"_custom_rms_norm"
if
"+"
in
rms_norm_custom_op
else
"_native_rms_norm"
)
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
rms_norm_custom_op
]
)
)
):
try
:
time_ms
=
benchmark_operation
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
input_global_scale
=
scale_fp4
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
)
results
[
f
"standard_allreduce_
{
suffix
}
_fp4_quant"
]
=
time_ms
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP4 failed: %s"
,
e
)
results
[
f
"standard_allreduce_
{
suffix
}
_fp4_quant"
]
=
float
(
"inf"
)
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
with
set_current_vllm_config
(
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
[
"-rms_norm"
]))
):
try
:
standard_allreduce_rmsnorm_fp4_quant_native_compiled
=
torch
.
compile
(
vllm_fused_allreduce
.
allreduce_rmsnorm_fp4_quant
,
fullgraph
=
True
,
dynamic
=
False
,
)
time_ms
=
benchmark_operation
(
standard_allreduce_rmsnorm_fp4_quant_native_compiled
,
input_tensor
,
residual
=
residual
,
quant_out
=
fp4_quant_out
,
input_global_scale
=
scale_fp4
,
output_scale
=
fp4_output_scale
,
)
results
[
"standard_allreduce_rmsnorm_fp4_quant_native_compiled"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s"
,
e
)
results
[
"standard_allreduce_rmsnorm_fp4_quant_native_compiled"
]
=
float
(
"inf"
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
for
use_oneshot
in
use_oneshot_options
:
suffix
=
"_oneshot"
if
use_oneshot
else
"_twoshot"
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
input_global_scale
=
scale_fp4
,
allreduce_params
=
allreduce_params
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
use_oneshot
=
use_oneshot
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s"
,
e
,
)
results
[
f
"flashinfer_fused_allreduce_rmsnorm_fp4_quant
{
suffix
}
"
]
=
(
float
(
"inf"
)
)
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
if
flashinfer_comm
is
not
None
and
allreduce_params
is
not
None
:
try
:
time_ms
=
benchmark_operation
(
flashinfer_fused_allreduce_rmsnorm_fp4_quant
,
input_tensor
,
residual
=
residual
,
norm_out
=
norm_out
,
rms_gamma
=
rms_gamma
,
rms_eps
=
rms_eps
,
input_global_scale
=
scale_fp4
,
allreduce_params
=
allreduce_params
,
quant_out
=
fp4_quant_out
,
output_scale
=
fp4_output_scale
,
use_oneshot
=
False
,
)
results
[
"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"
]
=
(
time_ms
)
except
Exception
as
e
:
logger
.
error
(
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s"
,
e
,
)
results
[
"flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"
]
=
float
(
"inf"
)
return
results
def
prepare_results_with_speedups
(
results_dict
):
"""Prepare results with speedup calculations based on dynamic baseline selection."""
prepared_results
=
[]
# Determine the fastest baseline for each operation type
def
get_fastest_baseline
(
op_name
,
results_dict
):
"""Get the fastest baseline between standard and native_compiled versions."""
if
"fp8_quant"
in
op_name
:
candidates
=
[
"standard_allreduce_rmsnorm_fp8_quant"
,
"standard_allreduce_rmsnorm_fp8_quant_native_compiled"
,
]
elif
"fp4_quant"
in
op_name
:
candidates
=
[
"standard_allreduce_rmsnorm_fp4_quant"
,
"standard_allreduce_rmsnorm_fp4_quant_native_compiled"
,
]
else
:
candidates
=
[
"standard_allreduce_rmsnorm"
,
"standard_allreduce_rmsnorm_native_compiled"
,
]
# Find the fastest among available candidates
fastest_time
=
float
(
"inf"
)
fastest_baseline
=
None
for
candidate
in
candidates
:
if
(
candidate
in
results_dict
and
results_dict
[
candidate
]
!=
float
(
"inf"
)
and
results_dict
[
candidate
]
<
fastest_time
):
fastest_time
=
results_dict
[
candidate
]
fastest_baseline
=
candidate
return
fastest_baseline
# Create dynamic baseline mapping
dynamic_baseline_mapping
=
{}
for
op_name
in
results_dict
:
if
(
op_name
.
startswith
(
"flashinfer_"
)
or
op_name
.
startswith
(
"standard_"
)
and
not
op_name
.
endswith
(
"_native_compiled"
)
):
dynamic_baseline_mapping
[
op_name
]
=
get_fastest_baseline
(
op_name
,
results_dict
)
for
op_name
,
time_ms
in
results_dict
.
items
():
if
time_ms
==
float
(
"inf"
):
speedup_str
=
"FAILED"
time_str
=
"FAILED"
else
:
time_str
=
f
"
{
time_ms
:.
3
f
}
"
# Find the appropriate baseline for this operation
baseline_op
=
dynamic_baseline_mapping
.
get
(
op_name
)
if
baseline_op
and
baseline_op
in
results_dict
:
baseline_time
=
results_dict
[
baseline_op
]
if
baseline_time
!=
float
(
"inf"
)
and
baseline_time
>
0
:
speedup
=
baseline_time
/
time_ms
speedup_str
=
f
"
{
speedup
:.
2
f
}
x"
else
:
speedup_str
=
"N/A"
else
:
# For baseline operations, determine if this is the fastest baseline
if
op_name
.
endswith
(
"_native_compiled"
)
or
(
op_name
.
startswith
(
"standard_"
)
and
not
op_name
.
endswith
(
"_native_compiled"
)
):
fastest_baseline
=
get_fastest_baseline
(
op_name
,
results_dict
)
if
fastest_baseline
==
op_name
:
speedup_str
=
"baseline"
else
:
if
fastest_baseline
and
fastest_baseline
in
results_dict
:
baseline_time
=
results_dict
[
fastest_baseline
]
if
baseline_time
!=
float
(
"inf"
)
and
baseline_time
>
0
:
speedup
=
baseline_time
/
time_ms
speedup_str
=
f
"
{
speedup
:.
2
f
}
x"
else
:
speedup_str
=
"N/A"
else
:
speedup_str
=
"N/A"
else
:
speedup_str
=
"N/A"
prepared_results
.
append
(
{
"operation"
:
op_name
,
"time_ms"
:
time_ms
,
"time_str"
:
time_str
,
"speedup_str"
:
speedup_str
,
}
)
return
prepared_results
def
print_results
(
results_dict
,
num_tokens
,
hidden_dim
,
dtype
,
use_residual
,
quant_modes
,
input_size_mb
,
):
"""Print benchmark results in a formatted table."""
print
(
f
"
\n
{
'='
*
80
}
"
)
print
(
f
"Results: num_tokens=
{
num_tokens
}
, hidden_dim=
{
hidden_dim
}
"
f
"(input size:
{
input_size_mb
:.
2
f
}
MB)"
)
print
(
f
"dtype=
{
dtype
}
, residual=
{
'yes'
if
use_residual
else
'no'
}
, "
f
"quant_modes=
{
','
.
join
(
sorted
(
list
(
quant_modes
)))
}
"
)
print
(
f
"
{
'='
*
80
}
"
)
print
(
f
"
{
'Operation'
:
<
50
}
{
'Time (ms)'
:
<
12
}
{
'Speedup'
:
<
10
}
"
)
print
(
f
"
{
'-'
*
80
}
"
)
# Prepare results with speedup calculations
prepared_results
=
prepare_results_with_speedups
(
results_dict
)
for
result
in
prepared_results
:
if
result
[
"time_ms"
]
==
float
(
"inf"
):
time_display
=
result
[
"time_str"
]
else
:
time_display
=
f
"
{
result
[
'time_ms'
]:.
3
f
}
"
print
(
f
"
{
result
[
'operation'
]:
<
50
}
{
time_display
:
<
12
}
{
result
[
'speedup_str'
]:
<
10
}
"
)
def
format_results_markdown
(
all_results
:
list
[
dict
],
world_size
:
int
,
args
:
argparse
.
Namespace
)
->
str
:
"""Format all benchmark results as markdown."""
lines
:
list
[
str
]
=
[]
lines
.
append
(
"# FlashInfer Fused Collective Operations Benchmark Results"
)
lines
.
append
(
""
)
lines
.
append
(
f
"**World Size:**
{
world_size
}
"
)
lines
.
append
(
f
"**Hidden Dimension:**
{
args
.
hidden_dim
}
"
)
lines
.
append
(
f
"**Warmup Iterations:**
{
args
.
warmup
}
"
)
lines
.
append
(
f
"**Benchmark Trials:**
{
args
.
trials
}
"
)
modes
=
","
.
join
(
all_results
[
0
][
"quant_modes"
])
if
all_results
else
"N/A"
lines
.
append
(
f
"**Quantization Modes:**
{
modes
}
"
)
lines
.
append
(
""
)
lines
.
append
(
"---"
)
lines
.
append
(
""
)
for
entry
in
all_results
:
num_tokens
=
entry
[
"num_tokens"
]
dtype
=
entry
[
"dtype"
]
use_residual
=
entry
[
"use_residual"
]
results_dict
=
entry
[
"results"
]
input_size_mb
=
entry
[
"input_size_mb"
]
residual_str
=
"with residual"
if
use_residual
else
"no residual"
lines
.
append
(
f
"## Configuration: num_tokens=
{
num_tokens
}
, dtype=
{
dtype
}
,
{
residual_str
}
"
)
lines
.
append
(
f
"**Input Size:**
{
input_size_mb
:.
2
f
}
MB"
)
lines
.
append
(
""
)
prepared
=
prepare_results_with_speedups
(
results_dict
)
# Build DataFrame for markdown export
rows
=
[
{
"Operation"
:
r
[
"operation"
].
replace
(
"_"
,
" "
).
title
(),
"Time (ms)"
:
r
[
"time_str"
],
"Speedup"
:
r
[
"speedup_str"
],
}
for
r
in
prepared
]
df
=
pd
.
DataFrame
(
rows
)
if
df
.
empty
:
lines
.
append
(
"No results."
)
else
:
lines
.
append
(
df
.
to_markdown
(
index
=
False
))
lines
.
append
(
""
)
return
"
\n
"
.
join
(
lines
)
def
save_results_to_file
(
all_results
:
list
[
dict
],
world_size
:
int
,
args
:
argparse
.
Namespace
,
rank
:
int
):
"""Save benchmark results to markdown file (only on rank 0)."""
if
rank
!=
0
:
return
if
not
all_results
:
logger
.
warning
(
"No results to save"
)
return
output_path
=
args
.
output_file
try
:
markdown_content
=
format_results_markdown
(
all_results
,
world_size
,
args
)
with
open
(
output_path
,
"a"
)
as
f
:
f
.
write
(
markdown_content
)
except
Exception
as
e
:
logger
.
error
(
"Failed to save results to file: %s"
,
e
)
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Benchmark fused collective operations"
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
nargs
=
"+"
,
default
=
[
128
,
512
,
1024
,
2048
],
help
=
"Numbers of tokens to test"
,
)
parser
.
add_argument
(
"--hidden-dim"
,
type
=
int
,
default
=
8192
,
help
=
"Hidden dimension size"
)
parser
.
add_argument
(
"--dtypes"
,
type
=
str
,
nargs
=
"+"
,
default
=
[
"bfloat16"
],
choices
=
[
"float16"
,
"bfloat16"
,
"float32"
],
help
=
"Data types to test"
,
)
parser
.
add_argument
(
"--no-residual"
,
action
=
"store_true"
,
help
=
"Skip residual connection tests"
,
)
parser
.
add_argument
(
"--quant-modes"
,
type
=
str
,
default
=
"none,fp8,fp4"
,
help
=
(
"Comma-separated quantization modes to run: none, fp8, fp4. "
"Default: none,fp8,fp4"
),
)
parser
.
add_argument
(
"--warmup"
,
type
=
int
,
default
=
5
,
help
=
"Number of warmup iterations"
)
parser
.
add_argument
(
"--trials"
,
type
=
int
,
default
=
20
,
help
=
"Number of benchmark trials"
)
parser
.
add_argument
(
"--output-file"
,
type
=
str
,
help
=
"""Output file path for markdown results
(default: benchmark_results_<timestamp>.md)
"""
,
)
parser
.
add_argument
(
"--no-oneshot"
,
action
=
"store_true"
,
help
=
"Skip oneshot benchmarks"
,
)
args
=
parser
.
parse_args
()
# Check if running with torchrun (required for collective operations)
if
"RANK"
not
in
os
.
environ
or
"WORLD_SIZE"
not
in
os
.
environ
:
raise
RuntimeError
(
"Must run with torchrun for distributed benchmarking. "
"Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
)
# Initialize distributed environment
rank
=
int
(
os
.
environ
[
"RANK"
])
world_size
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Validate world size (must be > 1 for collective operations)
if
world_size
<=
1
:
raise
ValueError
(
"World size must be > 1 for collective operations benchmarking. "
f
"Current world size:
{
world_size
}
. Use torchrun with --nproc_per_node > 1."
)
# Parse quantization modes
valid_quant_modes
=
{
"none"
,
"fp8"
,
"fp4"
}
raw_modes
=
[
m
.
strip
().
lower
()
for
m
in
(
args
.
quant_modes
or
""
).
split
(
","
)
if
m
.
strip
()
]
quant_modes
=
set
(
raw_modes
)
if
raw_modes
else
{
"none"
,
"fp8"
,
"fp4"
}
invalid
=
sorted
(
list
(
quant_modes
-
valid_quant_modes
))
if
invalid
:
raise
ValueError
(
f
"Invalid --quant-modes entries:
{
','
.
join
(
invalid
)
}
. "
f
"Valid options are:
{
','
.
join
(
sorted
(
valid_quant_modes
))
}
."
)
if
rank
==
0
:
logger
.
info
(
"Running benchmark with world_size=%s, rank=%s"
,
world_size
,
rank
)
logger
.
info
(
"Quantization modes: %s"
,
","
.
join
(
sorted
(
list
(
quant_modes
))))
if
flashinfer_comm
is
not
None
:
logger
.
info
(
"FlashInfer available - will benchmark fused operations"
,
)
else
:
logger
.
info
(
"FlashInfer not available - only benchmarking standard operations"
)
# Convert dtype strings to torch dtypes
dtype_map
=
{
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
,
"float32"
:
torch
.
float32
,
}
dtypes
=
[
dtype_map
[
dt
]
for
dt
in
args
.
dtypes
]
# Test configurations
residual_options
=
[
True
]
if
not
args
.
no_residual
else
[
False
]
configs
=
list
(
itertools
.
product
(
args
.
num_tokens
,
dtypes
,
residual_options
))
# Setup FlashInfer workspace if available
ipc_handles
=
None
allreduce_params
=
None
if
flashinfer_comm
is
not
None
:
# Use the largest hidden dimension for workspace setup
max_num_token
=
_FI_MAX_SIZES
.
get
(
world_size
)
//
(
args
.
hidden_dim
*
world_size
*
2
)
ipc_handles
,
workspace_tensor
=
setup_flashinfer_workspace
(
world_size
,
rank
,
args
.
hidden_dim
,
max_num_token
)
if
workspace_tensor
is
not
None
:
allreduce_params
=
FlashInferFusedAllReduceParams
(
rank
=
rank
,
world_size
=
world_size
,
max_token_num
=
max_num_token
,
)
# Collect all results for markdown export
all_results
=
[]
try
:
# Run benchmarks
for
num_tokens
,
dtype
,
use_residual
in
configs
:
if
rank
==
0
:
logger
.
info
(
"
\n
Testing: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s"
,
num_tokens
,
args
.
hidden_dim
,
dtype
,
use_residual
,
)
results
=
run_benchmarks
(
num_tokens
,
args
.
hidden_dim
,
dtype
,
use_residual
,
allreduce_params
,
quant_modes
=
quant_modes
,
no_oneshot
=
args
.
no_oneshot
,
)
# Store results for markdown export
if
rank
==
0
:
# Calculate input size in MB
input_size_mb
=
(
num_tokens
*
args
.
hidden_dim
*
torch
.
finfo
(
dtype
).
bits
)
/
(
8
*
1024
*
1024
)
all_results
.
append
(
{
"num_tokens"
:
num_tokens
,
"hidden_dim"
:
args
.
hidden_dim
,
"dtype"
:
str
(
dtype
).
replace
(
"torch."
,
""
),
"use_residual"
:
use_residual
,
"quant_modes"
:
sorted
(
list
(
quant_modes
)),
"input_size_mb"
:
input_size_mb
,
"results"
:
results
,
}
)
print_results
(
results
,
num_tokens
,
args
.
hidden_dim
,
dtype
,
use_residual
,
quant_modes
,
input_size_mb
,
)
# Save results to markdown file
if
args
.
output_file
and
rank
==
0
:
save_results_to_file
(
all_results
,
world_size
,
args
,
rank
)
finally
:
# Cleanup
if
ipc_handles
is
not
None
:
cleanup_flashinfer_workspace
(
ipc_handles
)
dist
.
barrier
()
if
__name__
==
"__main__"
:
main
()
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
41199996
...
...
@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts
,
fused_topk
,
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"
nm-testing
/Mixtral-8x7B-Instruct-v0.1"
,
"
nm-testing/d
eep
s
eek
v
2-
l
ite"
,
"
mistralai
/Mixtral-8x7B-Instruct-v0.1"
,
"
deepseek-ai/D
eep
S
eek
-V
2-
L
ite"
,
"ibm-granite/granite-3.0-1b-a400m"
,
"ibm-granite/granite-3.0-3b-a800m"
,
]
...
...
benchmarks/kernels/benchmark_layernorm.py
View file @
41199996
...
...
@@ -7,7 +7,8 @@ import torch
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
@
torch
.
inference_mode
()
...
...
benchmarks/kernels/benchmark_lora.py
View file @
41199996
...
...
@@ -6,11 +6,12 @@ import copy
import
json
import
pickle
import
time
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
itertools
import
product
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
import
torch
import
torch.utils.benchmark
as
TBenchmark
...
...
@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.lora.ops.triton_ops.utils
import
get_lora_op_configs
from
vllm.triton_utils
import
HAS_TRITON
,
triton
if
HAS_TRITON
:
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
,
lora_expand
,
lora_shrink
from
vllm.lora.ops.triton_ops
import
(
## added fused_moe_lora
LoRAKernelMeta
,
fused_moe_lora_expand
,
fused_moe_lora_shrink
,
lora_expand
,
lora_shrink
,
)
from
vllm.lora.ops.triton_ops.fused_moe_lora_op
import
(
_LORA_PTR_DICT
,
## added _LORA_PTR_DICT for fused_moe_lora
)
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.utils
import
FlexibleArgumentParser
from
vllm
import
_custom_ops
as
ops
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.math_utils
import
round_up
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_TP_SIZES
=
[
1
]
...
...
@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS
=
[
False
,
True
]
DEFAULT_SEQ_LENGTHS
=
[
1
]
DEFAULT_EXPAND_FN_ADD_INPUTS
=
[
True
,
False
]
DEFAULT_TOP_K_NUMS
=
[
1
]
# Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS
=
[
8
]
# Added for MoE LoRA num_experts
# Utilities
...
...
@@ -158,7 +172,7 @@ def ref_group_gemm(
seq_lens_cpu
:
torch
.
Tensor
,
prompt_lora_mapping_cpu
:
torch
.
Tensor
,
scaling
:
float
,
add_inputs
:
Optional
[
bool
]
,
add_inputs
:
bool
|
None
,
):
"""
Torch group gemm reference implementation to test correctness of
...
...
@@ -190,6 +204,11 @@ class OpType(Enum):
LORA_SHRINK
=
auto
()
LORA_EXPAND
=
auto
()
## Adding support for fused moe lora
FUSED_MOE_LORA_GATE_UP_SHRINK
=
auto
()
## Gate/Up projection variant with shrink
FUSED_MOE_LORA_GATE_UP_EXPAND
=
auto
()
## Gate/Up projection variant with expand
FUSED_MOE_LORA_DOWN_SHRINK
=
auto
()
## Down projection variant with shrink
FUSED_MOE_LORA_DOWN_EXPAND
=
auto
()
## Down projection variant with expand
@
staticmethod
def
from_str
(
s
:
str
)
->
"OpType"
:
...
...
@@ -197,6 +216,15 @@ class OpType(Enum):
return
OpType
.
LORA_SHRINK
if
s
.
lower
()
==
"lora_expand"
:
return
OpType
.
LORA_EXPAND
# Adding support for fused moe lora, both in gate_up and down
if
s
.
lower
()
==
"fused_moe_lora_gate_up_shrink"
:
## Gate/Up variant with shrink
return
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
if
s
.
lower
()
==
"fused_moe_lora_gate_up_expand"
:
## Gate/Up variant with expand
return
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
if
s
.
lower
()
==
"fused_moe_lora_down_shrink"
:
## Down variant with shrink
return
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
if
s
.
lower
()
==
"fused_moe_lora_down_expand"
:
## Down variant with expand
return
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
raise
ValueError
(
f
"Unrecognized str
{
s
}
to convert to OpType"
)
def
is_shrink_fn
(
self
)
->
bool
:
...
...
@@ -205,19 +233,56 @@ class OpType(Enum):
def
is_expand_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
LORA_EXPAND
]
def
is_fused_moe_lora_fn
(
self
)
->
bool
:
## adding for fused MoE LoRA
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]
def
is_fused_moe_lora_gate_up_fn
(
self
,
)
->
bool
:
## adding for fused MoE LoRA Gate/Up
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
]
def
is_fused_moe_lora_down_fn
(
self
)
->
bool
:
## adding for fused MoE LoRA Down
return
self
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]
def
is_fused_moe_lora_shrink_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
]
def
is_fused_moe_lora_expand_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]
def
num_slices
(
self
)
->
list
[
int
]:
if
self
.
is_fused_moe_lora_gate_up_fn
():
return
[
2
]
elif
self
.
is_fused_moe_lora_down_fn
():
return
[
1
]
return
[
1
,
2
,
3
]
def
mkn
(
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
)
->
tuple
[
int
,
int
,
int
]:
num_tokens
=
batch_size
*
seq_length
if
self
.
is_shrink_fn
():
if
self
.
is_shrink_fn
()
or
self
.
is_fused_moe_lora_fn
()
:
m
=
num_tokens
k
=
hidden_size
n
=
lora_rank
else
:
assert
self
.
is_expand_fn
()
elif
self
.
is_expand_fn
():
m
=
num_tokens
k
=
lora_rank
n
=
hidden_size
...
...
@@ -231,9 +296,36 @@ class OpType(Enum):
"""
if
self
.
is_shrink_fn
():
return
op_dtype
,
op_dtype
,
torch
.
float32
else
:
assert
self
.
is_expand_fn
()
elif
self
.
is_expand_fn
():
return
torch
.
float32
,
op_dtype
,
op_dtype
else
:
assert
self
.
is_fused_moe_lora_fn
()
return
op_dtype
,
op_dtype
,
op_dtype
def
matmul_shapes_fused_moe_lora
(
self
,
m
:
int
,
n
:
int
,
k
:
int
,
num_loras
:
int
,
num_slices
:
int
,
top_k_num
:
int
,
num_experts
:
int
,
)
->
tuple
[
tuple
[
int
],
tuple
[
int
],
tuple
[
int
],
tuple
[
int
]]:
if
self
.
is_fused_moe_lora_shrink_fn
():
input_shape
=
(
(
m
*
top_k_num
,
n
)
if
self
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
]
else
(
m
,
n
)
)
output_shape
=
(
num_slices
,
m
,
top_k_num
,
k
)
weight_shape
=
(
num_loras
,
num_experts
,
k
,
n
)
else
:
assert
self
.
is_fused_moe_lora_expand_fn
()
input_shape
=
(
num_slices
,
m
,
top_k_num
,
k
)
output_shape
=
(
m
,
top_k_num
,
n
*
num_slices
)
weight_shape
=
(
num_loras
,
num_experts
,
n
,
k
)
return
(
input_shape
,
weight_shape
,
output_shape
)
def
matmul_shapes
(
self
,
...
...
@@ -243,6 +335,8 @@ class OpType(Enum):
lora_rank
:
int
,
num_loras
:
int
,
num_slices
:
int
,
top_k_num
:
int
|
None
=
None
,
num_experts
:
int
|
None
=
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
"""
Given num_slices, return the shapes of the A, B, and C matrices
...
...
@@ -257,6 +351,16 @@ class OpType(Enum):
if
self
in
[
OpType
.
LORA_EXPAND
]:
# LoRA expand kernels support num_slices inherently in the kernel
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
if
self
.
is_fused_moe_lora_fn
():
return
self
.
matmul_shapes_fused_moe_lora
(
m
,
k
,
n
,
num_loras
,
num_slices
,
top_k_num
,
num_experts
,
)
raise
ValueError
(
f
"Unrecognized op_type
{
self
}
"
)
def
bench_fn
(
self
)
->
Callable
:
...
...
@@ -264,6 +368,16 @@ class OpType(Enum):
return
lora_shrink
if
self
==
OpType
.
LORA_EXPAND
:
return
lora_expand
if
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
]:
return
fused_moe_lora_shrink
if
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]:
return
fused_moe_lora_expand
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
...
...
@@ -316,8 +430,10 @@ class BenchmarkContext:
lora_rank
:
int
sort_by_lora_id
:
bool
dtype
:
torch
.
dtype
seq_length
:
Optional
[
int
]
=
None
num_slices
:
Optional
[
int
]
=
None
# num_slices for slice based ops
seq_length
:
int
|
None
=
None
num_experts
:
int
|
None
=
None
# num_experts for MoE based ops
top_k_num
:
int
|
None
=
None
# top_k for MoE based ops
num_slices
:
int
|
None
=
None
# num_slices for slice based ops
def
with_seq_length
(
self
,
seq_length
:
int
)
->
"BenchmarkContext"
:
ctx
=
copy
.
copy
(
self
)
...
...
@@ -372,6 +488,11 @@ class BenchmarkTensors:
f
"
{
dtype_to_str
(
self
.
output
.
dtype
)
}
"
)
def
get_num_tokens
(
self
,
size
:
int
,
top_k_num
:
int
,
op_type
:
OpType
):
return
(
size
*
top_k_num
if
op_type
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
]
else
size
)
@
staticmethod
def
make
(
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
device
:
str
=
"cuda"
...
...
@@ -384,6 +505,8 @@ class BenchmarkTensors:
ctx
.
lora_rank
,
ctx
.
num_loras
,
ctx
.
num_slices
,
ctx
.
top_k_num
,
ctx
.
num_experts
,
)
a_type
,
b_type
,
c_type
=
op_type
.
matmul_dtypes
(
ctx
.
dtype
)
input_tensor
,
lora_weights
,
output_tensor
=
make_rand_tensors
(
...
...
@@ -431,17 +554,27 @@ class BenchmarkTensors:
prompt_lora_indices_tensor
,
)
def
sanity_check
(
self
)
->
None
:
def
sanity_check
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
None
:
"""
Fails asserts when non-conformality is detected.
"""
num_tokens
=
self
.
input
.
shape
[
-
2
]
num_tokens
=
(
self
.
input
.
shape
[
1
]
if
op_type
.
is_fused_moe_lora_expand_fn
()
else
self
.
input
.
shape
[
-
2
]
)
# check metadata tensors
assert
torch
.
sum
(
self
.
seq_lens
)
==
num_tokens
## In down shrink case, each token is repeated top_k_num times
assert
num_tokens
==
self
.
get_num_tokens
(
torch
.
sum
(
self
.
seq_lens
),
ctx
.
top_k_num
,
op_type
),
f
"Expected
{
num_tokens
}
tokens, but got
{
torch
.
sum
(
self
.
seq_lens
)
}
"
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
# assert self.seq_start_loc.shape[0] == num_seqs
## In down shrink case, each prompt corresponds to top_k_num sequences
assert
self
.
prompt_lora_mapping
.
shape
[
0
]
==
num_seqs
assert
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
]
==
num_tokens
assert
self
.
get_num_tokens
(
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
],
ctx
.
top_k_num
,
op_type
)
def
to_device
(
self
,
device
:
str
):
"""
...
...
@@ -470,21 +603,111 @@ class BenchmarkTensors:
to_device
(
field
)
if
field_name
!=
"no_lora_flag_cpu"
else
field
,
)
def
metadata
(
self
)
->
tuple
[
int
,
int
,
int
]:
def
metadata
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
tuple
[
int
,
int
,
int
]:
"""
Return num_seqs, num_tokens and max_seq_len
"""
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_tokens
=
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
]
num_tokens
=
self
.
get_num_tokens
(
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
],
ctx
.
top_k_num
,
op_type
)
max_seq_len
=
torch
.
max
(
self
.
seq_lens
).
item
()
num_slices
=
len
(
self
.
lora_weights_lst
)
return
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
def
as_lora_shrink_kwargs
(
self
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
()
def
fused_moe_lora_data_prepare
(
self
,
block_size
:
int
,
token_lora_mapping
:
torch
.
Tensor
,
ctx
:
BenchmarkContext
,
):
def
moe_lora_align_block_size
(
topk_ids
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
max_loras
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
pad_sorted_ids
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
sorted_ids
=
torch
.
empty
(
(
max_loras
*
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids
=
torch
.
empty
(
(
max_loras
*
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
num_tokens_post_pad
=
torch
.
empty
(
(
max_loras
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
ops
.
moe_lora_align_block_size
(
topk_ids
,
token_lora_mapping
,
num_experts
,
block_size
,
max_loras
,
max_num_tokens_padded
,
max_num_m_blocks
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
num_tokens
=
ctx
.
batch_size
curr_topk_ids
=
torch
.
randint
(
0
,
ctx
.
num_experts
,
(
num_tokens
,
ctx
.
top_k_num
),
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
topk_weights
=
torch
.
randint
(
0
,
ctx
.
num_experts
,
(
num_tokens
,
ctx
.
top_k_num
),
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
(
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
)
=
(
moe_lora_align_block_size
(
topk_ids
=
curr_topk_ids
,
token_lora_mapping
=
token_lora_mapping
,
block_size
=
block_size
,
num_experts
=
ctx
.
num_experts
,
max_loras
=
ctx
.
num_loras
,
)
)
sorted_token_ids
=
sorted_token_ids_lora
.
view
(
ctx
.
num_loras
,
-
1
)
expert_ids
=
expert_ids_lora
.
view
(
ctx
.
num_loras
,
-
1
)
num_tokens_post_padded
=
num_tokens_post_padded_lora
return
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
def
as_lora_shrink_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
...
...
@@ -519,11 +742,13 @@ class BenchmarkTensors:
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
}
def
as_lora_expand_kwargs
(
self
,
add_inputs
:
bool
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
()
def
as_lora_expand_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
add_inputs
:
bool
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
...
...
@@ -560,22 +785,177 @@ class BenchmarkTensors:
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
}
def
as_fused_moe_lora_shrink_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
,
)
# Expected input shape : [num_tokens, hidden_size] for gate_up
# Expected input shape : [top_k_num * num_tokens, hidden_size] for down
assert
len
(
i_shape
)
==
2
assert
i_shape
[
0
]
==
num_tokens
hidden_size
=
i_shape
[
1
]
# Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
assert
len
(
lw_shape
)
==
4
assert
lw_shape
[
-
1
]
==
hidden_size
lora_rank
=
lw_shape
[
-
2
]
# Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert
len
(
o_shape
)
==
4
assert
(
o_shape
==
(
num_slices
,
num_tokens
//
ctx
.
top_k_num
,
ctx
.
top_k_num
,
lora_rank
)
if
op_type
in
[
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
]
else
o_shape
==
(
num_slices
,
num_tokens
,
ctx
.
top_k_num
,
lora_rank
)
)
kernel_config
=
get_lora_op_configs
(
op_type
.
name
.
lower
(),
max_loras
=
lw_shape
[
0
],
batch
=
num_tokens
,
hidden_size
=
hidden_size
,
rank
=
lora_rank
,
num_slices
=
num_slices
,
add_inputs
=
False
,
)
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
=
(
self
.
fused_moe_lora_data_prepare
(
block_size
=
kernel_config
[
"BLOCK_SIZE_M"
],
token_lora_mapping
=
self
.
lora_kernel_meta
.
token_lora_mapping
,
ctx
=
ctx
,
)
)
return
{
"qcurr_hidden_states"
:
self
.
input
,
"lora_a_stacked"
:
self
.
lora_weights_lst
,
"a_intermediate_cache1"
:
self
.
output
,
"topk_weights"
:
topk_weights
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"M"
:
topk_weights
.
shape
[
0
],
"EM"
:
sorted_token_ids
.
shape
[
1
],
"K"
:
self
.
input
.
shape
[
1
],
"num_tokens"
:
num_tokens
,
"num_experts"
:
ctx
.
num_experts
,
"num_slices"
:
num_slices
,
"shrink_block_size_m"
:
kernel_config
[
"BLOCK_SIZE_M"
],
"shrink_block_size_n"
:
kernel_config
[
"BLOCK_SIZE_N"
],
"shrink_block_size_k"
:
kernel_config
[
"BLOCK_SIZE_K"
],
"shrink_group_size_m"
:
kernel_config
[
"GROUP_SIZE_M"
],
"shrink_num_warps"
:
kernel_config
[
"NUM_WARPS"
],
"shrink_num_stages"
:
kernel_config
[
"NUM_STAGES"
],
"shrink_split_k"
:
kernel_config
.
get
(
"SPLIT_K"
,
1
),
"mul_routed_weight"
:
op_type
.
is_fused_moe_lora_down_fn
(),
}
def
as_fused_moe_lora_expand_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
,
)
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert
len
(
i_shape
)
==
4
assert
i_shape
[
0
]
==
num_slices
assert
i_shape
[
1
]
==
num_tokens
lora_rank
=
i_shape
[
-
1
]
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
assert
len
(
lw_shape
)
==
4
assert
lw_shape
[
-
1
]
==
lora_rank
hidden_size
=
lw_shape
[
-
2
]
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
assert
len
(
o_shape
)
==
3
assert
o_shape
==
(
num_tokens
,
ctx
.
top_k_num
,
hidden_size
*
num_slices
)
kernel_config
=
get_lora_op_configs
(
op_type
.
name
.
lower
(),
max_loras
=
lw_shape
[
0
],
batch
=
num_tokens
,
hidden_size
=
hidden_size
,
rank
=
lora_rank
,
num_slices
=
num_slices
,
add_inputs
=
False
,
)
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
=
(
self
.
fused_moe_lora_data_prepare
(
block_size
=
kernel_config
[
"BLOCK_SIZE_M"
],
token_lora_mapping
=
self
.
lora_kernel_meta
.
token_lora_mapping
,
ctx
=
ctx
,
)
)
return
{
"a_intermediate_cache1"
:
self
.
input
,
"lora_b_stacked"
:
self
.
lora_weights_lst
,
"output"
:
self
.
output
,
"topk_weights"
:
topk_weights
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"M"
:
topk_weights
.
shape
[
0
],
"EM"
:
sorted_token_ids
.
shape
[
1
],
"K"
:
self
.
input
.
shape
[
1
],
"num_tokens"
:
num_tokens
,
"num_experts"
:
ctx
.
num_experts
,
"num_slices"
:
num_slices
,
"max_lora_rank"
:
lora_rank
,
"w1_output_dim_size"
:
lw_shape
[
2
],
"expand_block_size_m"
:
kernel_config
[
"BLOCK_SIZE_M"
],
"expand_block_size_n"
:
kernel_config
[
"BLOCK_SIZE_N"
],
"expand_block_size_k"
:
kernel_config
[
"BLOCK_SIZE_K"
],
"expand_group_size_m"
:
kernel_config
[
"GROUP_SIZE_M"
],
"expand_num_warps"
:
kernel_config
[
"NUM_WARPS"
],
"expand_num_stages"
:
kernel_config
[
"NUM_STAGES"
],
"expand_split_k"
:
kernel_config
.
get
(
"SPLIT_K"
,
1
),
"mul_routed_weight"
:
op_type
.
is_fused_moe_lora_down_fn
(),
}
def
bench_fn_kwargs
(
self
,
op_type
:
OpType
,
add_inputs
:
Optional
[
bool
]
=
None
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
add_inputs
:
bool
|
None
=
None
)
->
dict
[
str
,
Any
]:
if
op_type
.
is_shrink_fn
():
if
op_type
.
is_shrink_fn
()
or
op_type
.
is_fused_moe_lora_fn
()
:
assert
add_inputs
is
None
else
:
assert
add_inputs
is
not
None
if
op_type
==
OpType
.
LORA_SHRINK
:
return
self
.
as_lora_shrink_kwargs
()
return
self
.
as_lora_shrink_kwargs
(
ctx
,
op_type
)
if
op_type
==
OpType
.
LORA_EXPAND
:
return
self
.
as_lora_expand_kwargs
(
add_inputs
)
return
self
.
as_lora_expand_kwargs
(
ctx
,
op_type
,
add_inputs
)
if
op_type
.
is_fused_moe_lora_shrink_fn
():
return
self
.
as_fused_moe_lora_shrink_kwargs
(
ctx
,
op_type
)
if
op_type
.
is_fused_moe_lora_expand_fn
():
return
self
.
as_fused_moe_lora_expand_kwargs
(
ctx
,
op_type
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
def
test_correctness
(
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
Optional
[
bool
]
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
bool
|
None
)
->
bool
:
"""
Test correctness of op_type implementation against a grouped gemm
...
...
@@ -611,12 +991,12 @@ def bench_optype(
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
expand_fn_add_inputs
:
Optional
[
bool
]
=
None
,
cuda_graph_nops
:
int
|
None
=
None
,
expand_fn_add_inputs
:
bool
|
None
=
None
,
test_correctness
:
bool
=
False
,
)
->
TMeasurement
:
assert
arg_pool_size
>=
1
if
op_type
.
is_shrink_fn
():
if
op_type
.
is_shrink_fn
()
or
op_type
.
is_fused_moe_lora_fn
()
:
assert
expand_fn_add_inputs
is
None
else
:
assert
expand_fn_add_inputs
is
not
None
...
...
@@ -626,23 +1006,30 @@ def bench_optype(
BenchmarkTensors
.
make
(
ctx
,
op_type
)
for
_
in
range
(
arg_pool_size
)
]
for
bt
in
bench_tensors
:
bt
.
sanity_check
()
bt
.
sanity_check
(
ctx
,
op_type
)
# Test correctness of our implementation.
if
test_correctness
:
assert
op_type
in
[
OpType
.
LORA_SHRINK
,
OpType
.
LORA_EXPAND
],
(
f
"Correctness testing is not supported for
{
op_type
.
name
}
."
)
assert
all
(
[
bt
.
test_correctness
(
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
[
bt
.
test_correctness
(
ctx
,
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
)
# BenchmarkTensors -> dict (kwargs)
kwargs_list
=
[
bt
.
bench_fn_kwargs
(
op_type
,
add_inputs
=
expand_fn_add_inputs
)
bt
.
bench_fn_kwargs
(
ctx
,
op_type
,
add_inputs
=
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
# Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT
.
clear
()
_LORA_B_PTR_DICT
.
clear
()
_LORA_PTR_DICT
.
clear
()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for
kwargs
in
kwargs_list
:
op_type
.
bench_fn
()(
**
kwargs
)
...
...
@@ -679,7 +1066,7 @@ def bench_torch_mm(
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
cuda_graph_nops
:
int
|
None
=
None
,
)
->
TMeasurement
:
"""
Benchmark basic torch.mm as a roofline.
...
...
@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
"""
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
Optional
[
argparse
.
Namespace
]
=
None
):
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
argparse
.
Namespace
|
None
=
None
):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
...
...
@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
# Benchmark bench_op
expand_fn_add_inputs
=
(
[
None
]
if
bench_op
.
is_shrink_fn
()
else
args
.
expand_fn_add_inputs
[
None
]
if
bench_op
.
is_shrink_fn
()
or
bench_op
.
is_fused_moe_lora_fn
()
else
args
.
expand_fn_add_inputs
)
for
add_input_arg
in
expand_fn_add_inputs
:
seq_len_timers
.
append
(
...
...
@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
hidden_sizes
:
list
[
int
],
lora_ranks
:
list
[
int
],
args
:
argparse
.
Namespace
)
->
list
[
BenchmarkContext
]:
ctxs
:
list
[
BenchmarkContext
]
=
[]
for
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
in
product
(
# noqa
for
(
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
,
top_k_num
,
num_experts
,
)
in
product
(
# noqa
args
.
batch_sizes
,
list
(
hidden_sizes
),
lora_ranks
,
args
.
num_loras
,
args
.
sort_by_lora_id
,
args
.
top_k_nums
,
args
.
num_experts
,
):
ctxs
.
append
(
BenchmarkContext
(
...
...
@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
seq_length
=
None
,
sort_by_lora_id
=
sort_by_lora_id
,
dtype
=
args
.
dtype
,
top_k_num
=
top_k_num
,
num_experts
=
num_experts
,
# To be filled based on the OpType to benchmark
num_slices
=
None
,
)
...
...
@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
),
)
p
.
add_argument
(
"--top-k-nums"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TOP_K_NUMS
,
help
=
"Top-K values for MoE LoRA operations"
,
)
p
.
add_argument
(
"--num-experts"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_NUM_EXPERTS
,
help
=
"Number of experts for MoE LoRA operations"
,
)
parser
=
FlexibleArgumentParser
(
description
=
f
"""
Benchmark LoRA kernels:
...
...
benchmarks/kernels/benchmark_machete.py
View file @
41199996
...
...
@@ -8,10 +8,9 @@ import math
import
os
import
pickle
as
pkl
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
dataclasses
import
dataclass
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
pandas
as
pd
import
torch
...
...
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights
,
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-3-8b"
,
"meta-llama/Llama-2-70b-hf"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
...
...
@@ -63,23 +62,23 @@ class BenchmarkTensors:
a
:
torch
.
Tensor
w_q
:
torch
.
Tensor
group_size
:
Optional
[
int
]
group_size
:
int
|
None
wtype
:
ScalarType
w_g_s
:
torch
.
Tensor
w_g_zp
:
Optional
[
torch
.
Tensor
]
w_ch_s
:
Optional
[
torch
.
Tensor
]
w_tok_s
:
Optional
[
torch
.
Tensor
]
w_g_zp
:
torch
.
Tensor
|
None
w_ch_s
:
torch
.
Tensor
|
None
w_tok_s
:
torch
.
Tensor
|
None
@
dataclass
class
TypeConfig
:
act_type
:
torch
.
dtype
weight_type
:
ScalarType
output_type
:
Optional
[
torch
.
dtype
]
group_scale_type
:
Optional
[
torch
.
dtype
]
group_zero_type
:
Optional
[
torch
.
dtype
]
channel_scale_type
:
Optional
[
torch
.
dtype
]
token_scale_type
:
Optional
[
torch
.
dtype
]
output_type
:
torch
.
dtype
|
None
group_scale_type
:
torch
.
dtype
|
None
group_zero_type
:
torch
.
dtype
|
None
channel_scale_type
:
torch
.
dtype
|
None
token_scale_type
:
torch
.
dtype
|
None
def
rand_data
(
shape
,
dtype
=
torch
.
float16
,
scale
=
1
):
...
...
@@ -93,8 +92,8 @@ def quantize_and_pack(
atype
:
torch
.
dtype
,
w
:
torch
.
Tensor
,
wtype
:
ScalarType
,
stype
:
Optional
[
torch
.
dtype
]
,
group_size
:
Optional
[
int
]
,
stype
:
torch
.
dtype
|
None
,
group_size
:
int
|
None
,
zero_points
:
bool
=
False
,
):
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
...
...
@@ -113,7 +112,7 @@ def quantize_and_pack(
def
create_bench_tensors
(
shape
:
tuple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
Optional
[
int
]
shape
:
tuple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
int
|
None
)
->
list
[
BenchmarkTensors
]:
m
,
n
,
k
=
shape
...
...
@@ -238,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight
=
w_q
,
b_bias
=
None
,
b_scales
=
w_s
,
a_scales
=
None
,
global_scale
=
None
,
b_zeros
=
w_zp
,
g_idx
=
g_idx
,
...
...
@@ -331,8 +331,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return
res
_SWEEP_SCHEDULES_RESULTS
:
Optional
[
pd
.
DataFrame
]
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
Optional
[
str
]
=
None
_SWEEP_SCHEDULES_RESULTS
:
pd
.
DataFrame
|
None
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
str
|
None
=
None
def
bench
(
...
...
benchmarks/kernels/benchmark_marlin.py
View file @
41199996
...
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
sort_weights
,
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
...
...
@@ -263,7 +263,7 @@ def bench_run(
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s,
None,
marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
@@ -273,7 +273,7 @@ def bench_run(
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s,
None,
marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
benchmarks/kernels/benchmark_moe.py
View file @
41199996
...
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
...
...
@@ -228,8 +228,8 @@ def benchmark_config(
# run()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
...
...
@@ -253,10 +253,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_k_range
=
[
32
,
64
,
128
,
256
]
if
not
use_fp16
:
block_k_range
.
remove
(
16
)
# BLOCK_K=16 not supported for fp8
num_warps_range
=
[
2
,
4
,
8
]
group_m_range
=
[
1
,
16
,
32
,
64
]
num_stage_range
=
[
2
,
3
,
4
,
5
]
# waves_per_eu_range = [0]
num_warps_range
=
[
1
,
2
,
4
,
8
]
group_m_range
=
[
1
,
4
,
8
,
16
,
32
]
num_stage_range
=
[
2
]
# waves_per_eu_range = [0
, 1, 2, 4
]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else []
...
...
@@ -669,19 +669,23 @@ def main(args: argparse.Namespace):
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"DeepseekV2ForCausalLM"
,
"DeepseekV3ForCausalLM"
,
"DeepseekV32ForCausalLM"
,
"Glm4MoeForCausalLM"
,
"NemotronHForCausalLM"
,
):
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
,
...
...
@@ -690,14 +694,27 @@ def main(args: argparse.Namespace):
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
==
"Qwen3VLMoeForConditionalGeneration"
:
text_config
=
config
.
get_text_config
()
E
=
text_config
.
num_experts
topk
=
text_config
.
num_experts_per_tok
intermediate_size
=
text_config
.
moe_intermediate_size
hidden_size
=
text_config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"HunYuanMoEV1ForCausalLM"
):
E
=
config
.
num_experts
topk
=
config
.
moe_topk
[
0
]
intermediate_size
=
config
.
moe_intermediate_size
[
0
]
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"Step3VLForConditionalGeneration"
):
E
=
config
.
text_config
.
moe_num_experts
topk
=
config
.
text_config
.
moe_top_k
intermediate_size
=
config
.
text_config
.
moe_intermediate_size
elif
config
.
architectures
[
0
]
in
[
"Qwen3OmniMoeForConditionalGeneration"
]:
E
=
config
.
thinker_config
.
text_config
.
num_experts
topk
=
config
.
thinker_config
.
text_config
.
num_experts_per_tok
intermediate_size
=
config
.
thinker_config
.
text_config
.
moe_intermediate_size
hidden_size
=
config
.
thinker_config
.
text_config
.
hidden_size
else
:
# Support for llama4
config
=
config
.
get_text_config
()
...
...
@@ -705,16 +722,16 @@ def main(args: argparse.Namespace):
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
hidden_size
=
config
.
hidden_size
enable_ep
=
bool
(
args
.
enable_expert_parallel
)
if
enable_ep
:
ensure_divisibility
(
E
,
tp_size
,
"Number of experts"
)
E
=
E
//
tp_size
shard_intermediate_size
=
2
*
intermediate_size
else
:
ensure_divisibility
(
intermediate_size
,
tp_size
,
"intermediate_size"
)
ensure_divisibility
(
intermediate_size
,
args
.
tp_size
,
"intermediate_size"
)
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
...
...
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
41199996
...
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -105,8 +105,8 @@ def benchmark_permute(
graph
.
replay
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
...
...
@@ -241,8 +241,8 @@ def benchmark_unpermute(
graph
.
replay
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
...
...
@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_
dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_customized_permute
=
args
.
use_customized_permute
...
...
benchmarks/kernels/benchmark_mrope.py
View file @
41199996
...
...
@@ -6,7 +6,7 @@
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
#
rope_theta,
is_neox_style, rope_
scaling
, dtype, torch_mean, torch_median, torch_p99,
# is_neox_style, rope_
parameters
, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
...
...
@@ -39,7 +39,7 @@ import torch
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
...
@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
8192
,
rope_theta
:
float
=
10000
,
is_neox_style
:
bool
=
True
,
rope_
scaling
:
dict
[
str
,
Any
]
=
None
,
rope_
parameters
:
dict
[
str
,
Any
]
|
None
=
None
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
seed
:
int
=
0
,
warmup_iter
:
int
=
10
,
...
...
@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size
=
head_dim
,
rotary_dim
=
head_dim
,
max_position
=
max_position
,
base
=
rope_theta
,
is_neox_style
=
is_neox_style
,
rope_
scaling
=
rope_scaling
,
rope_
parameters
=
rope_parameters
,
dtype
=
dtype
,
).
to
(
device
=
device
)
...
...
@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads
,
head_dim
,
max_position
,
rope_theta
,
is_neox_style
,
str
(
rope_
scaling
),
str
(
rope_
parameters
),
str
(
dtype
).
split
(
"."
)[
-
1
],
torch_stats
[
"mean"
],
torch_stats
[
"median"
],
...
...
@@ -255,9 +252,8 @@ if __name__ == "__main__":
"num_kv_heads"
,
"head_dim"
,
"max_position"
,
"rope_theta"
,
"is_neox_style"
,
"rope_
scaling
"
,
"rope_
parameters
"
,
"dtype"
,
"torch_mean"
,
"torch_median"
,
...
...
@@ -303,7 +299,7 @@ if __name__ == "__main__":
q_size
=
num_heads
*
head_dim
kv_size
=
num_kv_heads
*
head_dim
is_neox_style
=
True
rope_
theta
=
config
.
rope_
theta
rope_
parameters
=
config
.
rope_
parameters
max_position
=
config
.
max_position_embeddings
for
num_tokens
in
num_tokens_list
:
...
...
@@ -315,9 +311,8 @@ if __name__ == "__main__":
num_heads
=
num_heads
,
num_kv_heads
=
num_kv_heads
,
max_position
=
max_position
,
rope_theta
=
rope_theta
,
is_neox_style
=
is_neox_style
,
rope_
scaling
=
config
.
rope_scaling
,
rope_
parameters
=
rope_parameters
,
dtype
=
getattr
(
torch
,
args
.
dtype
),
seed
=
args
.
seed
,
warmup_iter
=
args
.
warmup_iter
,
...
...
benchmarks/kernels/benchmark_paged_attention.py
View file @
41199996
...
...
@@ -3,20 +3,18 @@
import
random
import
time
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
import
vllm.envs
as
envs
from
vllm.utils
import
(
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
,
)
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
...
...
@@ -39,7 +37,7 @@ def main(
seed
:
int
,
do_profile
:
bool
,
device
:
str
=
"cuda"
,
kv_cache_dtype
:
Optional
[
str
]
=
None
,
kv_cache_dtype
:
str
|
None
=
None
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
...
...
benchmarks/kernels/benchmark_per_token_group_quant.py
View file @
41199996
...
...
@@ -3,8 +3,8 @@
import
argparse
import
math
from
collections.abc
import
Callable
from
contextlib
import
contextmanager
from
typing
import
Callable
from
unittest.mock
import
patch
import
torch
...
...
@@ -30,8 +30,8 @@ def _time_cuda(
fn
()
torch
.
cuda
.
synchronize
()
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
=
torch
.
Event
(
enable_timing
=
True
)
end
=
torch
.
Event
(
enable_timing
=
True
)
start
.
record
()
for
_
in
range
(
bench_iters
):
...
...
benchmarks/kernels/benchmark_polynorm.py
deleted
100644 → 0
View file @
31021d81
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
torch
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
def
polynorm_naive
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
eps
:
float
=
1e-6
,
):
orig_shape
=
x
.
shape
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
def
norm
(
x
,
eps
:
float
):
return
x
/
torch
.
sqrt
(
x
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
+
eps
)
x
=
x
.
float
()
return
(
(
weight
[
0
]
*
norm
(
x
**
3
,
eps
)
+
weight
[
1
]
*
norm
(
x
**
2
,
eps
)
+
weight
[
2
]
*
norm
(
x
,
eps
)
+
bias
)
.
to
(
weight
.
dtype
)
.
view
(
orig_shape
)
)
def
polynorm_vllm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
eps
:
float
=
1e-6
,
):
orig_shape
=
x
.
shape
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
out
=
torch
.
empty_like
(
x
)
vllm_ops
.
poly_norm
(
out
,
x
,
weight
,
bias
,
eps
)
output
=
out
output
=
output
.
view
(
orig_shape
)
return
output
def
calculate_diff
(
batch_size
,
seq_len
,
hidden_dim
):
dtype
=
torch
.
bfloat16
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_dim
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
3
,
dtype
=
dtype
,
device
=
"cuda"
)
bias
=
torch
.
ones
(
1
,
dtype
=
dtype
,
device
=
"cuda"
)
output_naive
=
polynorm_naive
(
x
,
weight
,
bias
)
output_vllm
=
polynorm_vllm
(
x
,
weight
,
bias
)
if
torch
.
allclose
(
output_naive
,
output_vllm
,
atol
=
1e-2
,
rtol
=
1e-2
):
print
(
"✅ All implementations match"
)
else
:
print
(
"❌ Implementations differ"
)
batch_size_range
=
[
2
**
i
for
i
in
range
(
0
,
7
,
2
)]
seq_length_range
=
[
2
**
i
for
i
in
range
(
6
,
11
,
1
)]
dim_range
=
[
2048
,
4096
]
configs
=
list
(
itertools
.
product
(
dim_range
,
batch_size_range
,
seq_length_range
))
def
get_benchmark
():
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"dim"
,
"batch_size"
,
"seq_len"
],
x_vals
=
[
list
(
_
)
for
_
in
configs
],
line_arg
=
"provider"
,
line_vals
=
[
"naive"
,
"vllm"
],
line_names
=
[
"Naive"
,
"vLLM"
],
styles
=
[(
"blue"
,
"-"
),
(
"red"
,
"-"
)],
ylabel
=
"us"
,
plot_name
=
"polynorm-perf"
,
args
=
{},
)
)
def
benchmark
(
dim
,
batch_size
,
seq_len
,
provider
):
dtype
=
torch
.
bfloat16
hidden_dim
=
dim
*
4
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_dim
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
3
,
dtype
=
dtype
,
device
=
"cuda"
)
bias
=
torch
.
ones
(
1
,
dtype
=
dtype
,
device
=
"cuda"
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"naive"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
polynorm_naive
(
x
,
weight
,
bias
),
quantiles
=
quantiles
,
)
else
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
polynorm_vllm
(
x
,
weight
,
bias
),
quantiles
=
quantiles
,
)
return
1000
*
ms
,
1000
*
max_ms
,
1000
*
min_ms
return
benchmark
if
__name__
==
"__main__"
:
import
argparse
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
4
,
help
=
"Batch size"
,
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
128
,
help
=
"Sequence length"
,
)
parser
.
add_argument
(
"--hidden-dim"
,
type
=
int
,
default
=
8192
,
help
=
"Intermediate size of MLP"
,
)
parser
.
add_argument
(
"--save-path"
,
type
=
str
,
default
=
"./configs/polnorm/"
,
help
=
"Path to save polnorm benchmark results"
,
)
args
=
parser
.
parse_args
()
# Run correctness test
calculate_diff
(
batch_size
=
args
.
batch_size
,
seq_len
=
args
.
seq_len
,
hidden_dim
=
args
.
hidden_dim
,
)
benchmark
=
get_benchmark
()
# Run performance benchmark
benchmark
.
run
(
print_data
=
True
,
save_path
=
args
.
save_path
)
benchmarks/kernels/benchmark_quant.py
View file @
41199996
...
...
@@ -7,7 +7,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
@
torch
.
inference_mode
()
...
...
benchmarks/kernels/benchmark_reshape_and_cache.py
0 → 100644
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
time
import
torch
from
tabulate
import
tabulate
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
create_kv_caches_with_random
,
)
logger
=
init_logger
(
__name__
)
@
torch
.
inference_mode
()
def
run_benchmark
(
num_tokens
:
int
,
num_heads
:
int
,
head_size
:
int
,
block_size
:
int
,
num_blocks
:
int
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
num_iters
:
int
,
benchmark_mode
:
str
,
device
:
str
=
"cuda"
,
)
->
float
:
"""Return latency (seconds) for given num_tokens."""
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
raise
ValueError
(
"fp8 kv-cache requires head_size to be a multiple of 16."
)
current_platform
.
seed_everything
(
42
)
torch
.
set_default_device
(
device
)
# create random key / value tensors [T, H, D].
key
=
torch
.
randn
(
num_tokens
,
num_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
)
value
=
torch
.
randn_like
(
key
)
# prepare the slot mapping.
# each token is assigned a unique slot in the KV-cache.
num_slots
=
block_size
*
num_blocks
if
num_tokens
>
num_slots
:
raise
ValueError
(
"num_tokens cannot exceed the total number of cache slots"
)
slot_mapping_lst
=
random
.
sample
(
range
(
num_slots
),
num_tokens
)
slot_mapping
=
torch
.
tensor
(
slot_mapping_lst
,
dtype
=
torch
.
long
,
device
=
device
)
key_caches
,
value_caches
=
create_kv_caches_with_random
(
num_blocks
,
block_size
,
1
,
# num_layers
num_heads
,
head_size
,
kv_cache_dtype
,
dtype
,
device
=
device
,
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# to free unused memory
del
key_caches
,
value_caches
# compute per-kernel scaling factors for fp8 conversion (if used).
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
function_under_test
=
lambda
:
ops
.
reshape_and_cache
(
key
,
# noqa: F821
value
,
# noqa: F821
key_cache
,
# noqa: F821
value_cache
,
# noqa: F821
slot_mapping
,
# noqa: F821
kv_cache_dtype
,
k_scale
,
v_scale
,
)
if
benchmark_mode
==
"cudagraph"
:
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
function_under_test
()
torch
.
cuda
.
synchronize
()
function_under_test
=
lambda
:
g
.
replay
()
def
run_cuda_benchmark
(
n_iters
:
int
)
->
float
:
nonlocal
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
torch
.
cuda
.
synchronize
()
start
=
time
.
perf_counter
()
for
_
in
range
(
n_iters
):
function_under_test
()
torch
.
cuda
.
synchronize
()
end
=
time
.
perf_counter
()
return
(
end
-
start
)
/
n_iters
# warm-up
run_cuda_benchmark
(
3
)
lat
=
run_cuda_benchmark
(
num_iters
)
# free tensors to mitigate OOM when sweeping
del
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
torch
.
cuda
.
empty_cache
()
return
lat
def
main
(
args
):
rows
=
[]
for
exp
in
range
(
1
,
17
):
n_tok
=
2
**
exp
lat
=
run_benchmark
(
num_tokens
=
n_tok
,
num_heads
=
args
.
num_heads
,
head_size
=
args
.
head_size
,
block_size
=
args
.
block_size
,
num_blocks
=
args
.
num_blocks
,
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
kv_cache_dtype
=
args
.
kv_cache_dtype
,
num_iters
=
args
.
iters
,
benchmark_mode
=
args
.
mode
,
device
=
"cuda"
,
)
rows
.
append
([
n_tok
,
lat
*
1e6
])
# convert to microseconds
print
(
f
"Benchmark results for implementation cuda (measuring with
{
args
.
mode
}
):"
)
print
(
tabulate
(
rows
,
headers
=
[
"num_tokens"
,
"latency (µs)"
],
floatfmt
=
".3f"
))
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
128
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
,
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--num-blocks"
,
type
=
int
,
default
=
128
*
128
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"bfloat16"
,
)
parser
.
add_argument
(
"--kv-cache-dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8"
],
default
=
"auto"
,
)
parser
.
add_argument
(
"--iters"
,
type
=
int
,
default
=
200
)
parser
.
add_argument
(
"--mode"
,
type
=
str
,
choices
=
[
"cudagraph"
,
"no_graph"
],
default
=
"cudagraph"
,
)
args
=
parser
.
parse_args
()
main
(
args
)
Prev
1
2
3
4
5
6
7
8
9
10
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment