Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41199996
Commit
41199996
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.12.0' into v0.12.0-dev
parents
31021d81
4fd9d6a8
Changes
380
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1847 additions
and
270 deletions
+1847
-270
benchmarks/kernels/bench_per_token_quant_fp8.py
benchmarks/kernels/bench_per_token_quant_fp8.py
+3
-2
benchmarks/kernels/benchmark_activation.py
benchmarks/kernels/benchmark_activation.py
+2
-1
benchmarks/kernels/benchmark_bitblas.py
benchmarks/kernels/benchmark_bitblas.py
+1
-1
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
+1
-1
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+3
-3
benchmarks/kernels/benchmark_device_communicators.py
benchmarks/kernels/benchmark_device_communicators.py
+4
-4
benchmarks/kernels/benchmark_fused_collective.py
benchmarks/kernels/benchmark_fused_collective.py
+1129
-0
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+3
-3
benchmarks/kernels/benchmark_layernorm.py
benchmarks/kernels/benchmark_layernorm.py
+2
-1
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+457
-40
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+17
-17
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+3
-3
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+27
-10
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+6
-6
benchmarks/kernels/benchmark_mrope.py
benchmarks/kernels/benchmark_mrope.py
+8
-13
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+4
-6
benchmarks/kernels/benchmark_per_token_group_quant.py
benchmarks/kernels/benchmark_per_token_group_quant.py
+3
-3
benchmarks/kernels/benchmark_polynorm.py
benchmarks/kernels/benchmark_polynorm.py
+0
-155
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+2
-1
benchmarks/kernels/benchmark_reshape_and_cache.py
benchmarks/kernels/benchmark_reshape_and_cache.py
+172
-0
No files found.
Too many changes to show.
To preserve performance only
380 of 380+
files are displayed.
Plain diff
Email patch
benchmarks/kernels/bench_per_token_quant_fp8.py
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
itertools
from
typing
import
Callable
from
collections.abc
import
Callable
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pandas
as
pd
import
pandas
as
pd
...
@@ -10,7 +10,8 @@ import torch
...
@@ -10,7 +10,8 @@ import torch
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
def
with_triton_mode
(
fn
):
def
with_triton_mode
(
fn
):
...
...
benchmarks/kernels/benchmark_activation.py
View file @
41199996
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
...
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
batch_size_range
=
[
1
,
16
,
32
,
64
,
128
]
batch_size_range
=
[
1
,
16
,
32
,
64
,
128
]
seq_len_range
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
seq_len_range
=
[
1
,
16
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
...
...
benchmarks/kernels/benchmark_bitblas.py
View file @
41199996
...
@@ -28,7 +28,7 @@ except ImportError as e:
...
@@ -28,7 +28,7 @@ except ImportError as e:
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark BitBLAS int4 on a specific target."
description
=
"Benchmark BitBLAS int4 on a specific target."
...
...
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
View file @
41199996
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
WEIGHT_SHAPES_MOE
=
{
WEIGHT_SHAPES_MOE
=
{
"nvidia/DeepSeek-R1-FP4"
:
[
"nvidia/DeepSeek-R1-FP4"
:
[
...
...
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
View file @
41199996
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_confi
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
,
fused_topk
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# Weight shapes for different models: [num_experts, topk, hidden_size,
# Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size]
# intermediate_size]
...
@@ -255,8 +255,8 @@ def bench_run(
...
@@ -255,8 +255,8 @@ def bench_run(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Timing
# Timing
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
=
[]
latencies
=
[]
for
_
in
range
(
num_iters
):
for
_
in
range
(
num_iters
):
...
...
benchmarks/kernels/benchmark_device_communicators.py
View file @
41199996
...
@@ -22,8 +22,8 @@ Example:
...
@@ -22,8 +22,8 @@ Example:
import
json
import
json
import
os
import
os
import
time
import
time
from
collections.abc
import
Callable
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
typing
import
Callable
,
Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
...
@@ -39,7 +39,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
)
)
from
vllm.distributed.device_communicators.symm_mem
import
SymmMemCommunicator
from
vllm.distributed.device_communicators.symm_mem
import
SymmMemCommunicator
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -264,12 +264,12 @@ class CommunicatorBenchmark:
...
@@ -264,12 +264,12 @@ class CommunicatorBenchmark:
def
benchmark_allreduce_single
(
def
benchmark_allreduce_single
(
self
,
self
,
sequence_length
:
int
,
sequence_length
:
int
,
allreduce_fn
:
Callable
[[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]
],
allreduce_fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
|
None
],
should_use_fn
:
Callable
[[
torch
.
Tensor
],
bool
],
should_use_fn
:
Callable
[[
torch
.
Tensor
],
bool
],
context
,
context
,
num_warmup
:
int
,
num_warmup
:
int
,
num_trials
:
int
,
num_trials
:
int
,
)
->
Optional
[
float
]
:
)
->
float
|
None
:
"""Benchmark method with CUDA graph optimization."""
"""Benchmark method with CUDA graph optimization."""
try
:
try
:
# Create test tensor (2D: sequence_length x hidden_size)
# Create test tensor (2D: sequence_length x hidden_size)
...
...
benchmarks/kernels/benchmark_fused_collective.py
0 → 100644
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark for FlashInfer fused collective operations vs standard operations.
This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations
Usage with torchrun:
torchrun --nproc_per_node=2 benchmark_fused_collective.py
"""
import
argparse
import
itertools
import
os
import
time
import
pandas
as
pd
import
torch
# type: ignore
import
torch.distributed
as
dist
# type: ignore
from
vllm.config.vllm
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.distributed
import
(
get_tp_group
,
tensor_model_parallel_all_reduce
,
)
from
vllm.distributed.parallel_state
import
(
graph_capture
,
init_distributed_environment
,
initialize_model_parallel
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
# noqa
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
# noqa
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
# noqa
from
vllm.platforms
import
current_platform
# noqa
# Module-level aliases for ops looked up on the ``torch.ops._C`` namespace.
# The attribute lookups run at import time.
# NOTE(review): presumably these ops are registered by vLLM's compiled
# extension before this module is imported — confirm on the target platform.
RMS_NORM_OP = torch.ops._C.rms_norm
FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = (
    torch.ops._C.fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant

# Module logger, named after this module.
logger = init_logger(__name__)
# Try to import FlashInfer
try
:
import
flashinfer.comm
as
flashinfer_comm
# type: ignore
if
not
hasattr
(
flashinfer_comm
,
"trtllm_allreduce_fusion"
):
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer comm module found but missing trtllm_allreduce_fusion"
)
except
ImportError
:
flashinfer_comm
=
None
logger
.
warning
(
"FlashInfer not found, only benchmarking standard operations"
)
# Constants
# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# Number of bytes in one mebibyte.
MiB = 1024 * 1024

# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --disable-oneshot to disable oneshot mode for very large input sizes
# The keys also act as the set of supported world sizes:
# setup_flashinfer_workspace() refuses any world size that is not a key here.
_FI_MAX_SIZES = {
    2: 64 * MiB,  # 64MB
    4: 64 * MiB,  # 64MB
    8: 64 * MiB,  # 64MB
}

# Global workspace tensor for FlashInfer
# Assigned by setup_flashinfer_workspace(); stays None until that succeeds.
_FI_WORKSPACE_TENSOR = None
def setup_flashinfer_workspace(
    world_size: int,
    rank: int,
    hidden_dim: int,
    max_token_num: int,
    use_fp32_lamport: bool = False,
):
    """Create the FlashInfer IPC workspace used by the fused allreduce calls.

    On success the workspace tensor is also stored in the module-level
    ``_FI_WORKSPACE_TENSOR``.

    Args:
        world_size: Tensor-parallel world size; must be a key of
            ``_FI_MAX_SIZES``.
        rank: Tensor-parallel rank of this process.
        hidden_dim: Forwarded to FlashInfer as ``hidden_dim``.
        max_token_num: Forwarded to FlashInfer as ``max_token_num``.
        use_fp32_lamport: Forwarded to FlashInfer unchanged.

    Returns:
        ``(ipc_handles, workspace_tensor)`` on success, or ``(None, None)``
        when FlashInfer is unavailable, the world size is unsupported, or
        workspace creation raised (the error is logged, not re-raised).
    """
    global _FI_WORKSPACE_TENSOR

    unavailable = (None, None)

    if flashinfer_comm is None:
        return unavailable

    if world_size not in _FI_MAX_SIZES:
        logger.warning("FlashInfer not supported for world size %s", world_size)
        return unavailable

    try:
        # Create IPC workspace
        handles, workspace = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=world_size,
                max_token_num=max_token_num,
                hidden_dim=hidden_dim,
                group=get_tp_group().device_group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )
    except Exception as e:
        # Best effort: the caller falls back to the non-FlashInfer benchmarks.
        logger.error("Failed to setup FlashInfer workspace: %s", e)
        return unavailable

    _FI_WORKSPACE_TENSOR = workspace
    return handles, workspace
def cleanup_flashinfer_workspace(ipc_handles):
    """Destroy a FlashInfer IPC workspace previously created for this process.

    Does nothing when FlashInfer is unavailable or ``ipc_handles`` is None.
    Errors raised during destruction are logged and swallowed.

    Args:
        ipc_handles: Handles returned by ``setup_flashinfer_workspace``.
    """
    nothing_to_release = flashinfer_comm is None or ipc_handles is None
    if nothing_to_release:
        return

    try:
        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
            ipc_handles, get_tp_group().device_group
        )
    except Exception as e:
        logger.error("Failed to cleanup FlashInfer workspace: %s", e)
class FlashInferFusedAllReduceParams:
    """Bundle of settings passed to FlashInfer's fused allreduce call.

    ``rank``, ``world_size``, ``use_fp32_lamport`` and ``max_token_num`` come
    from the caller; the remaining flags are fixed to ``True``.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
    ):
        # Caller-provided values.
        self.rank, self.world_size = rank, world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.max_token_num = max_token_num

        # Fixed launch options.
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True

    def get_trtllm_fused_allreduce_kwargs(self):
        """Return the keyword arguments shared by every fused allreduce call."""
        kwargs = {"world_rank": self.rank, "world_size": self.world_size}
        for flag in ("launch_with_pdl", "trigger_completion_at_end", "fp32_acc"):
            kwargs[flag] = getattr(self, flag)
        return kwargs
def _trtllm_fused_allreduce(
    pattern_name: str,
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    allreduce_params: "FlashInferFusedAllReduceParams",
    use_oneshot: bool,
    norm_out: torch.Tensor | None,
    quant_out: torch.Tensor | None = None,
    scale_out: torch.Tensor | None = None,
    scale_factor: torch.Tensor | None = None,
):
    """Shared launcher behind the three public fused-allreduce wrappers.

    Args:
        pattern_name: Attribute name on
            ``flashinfer_comm.AllReduceFusionPattern`` selecting the fusion.
            It is resolved only after the availability check so a missing
            FlashInfer raises ``RuntimeError`` rather than ``AttributeError``.
        input_tensor: Passed as ``allreduce_in``; its first and last
            dimensions are passed as ``token_num`` and ``hidden_dim``.
        residual: Passed as ``residual_in``.
        rms_gamma: Passed as ``rms_gamma``.
        rms_eps: Passed as ``rms_eps``.
        allreduce_params: Supplies the remaining launch keyword arguments.
        use_oneshot: Passed as ``use_oneshot``.
        norm_out: Buffer for the normalized output, or None (see below).
        quant_out: Passed as ``quant_out``.
        scale_out: Passed as ``scale_out``.
        scale_factor: Passed as ``scale_factor``.

    Raises:
        RuntimeError: If FlashInfer is unavailable or the global workspace
            has not been set up.
    """
    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    if norm_out is None:
        # No separate output buffer: input_tensor is handed over as norm_out
        # and the residual buffer is handed over as residual_out.
        norm_out = input_tensor
        residual_out = residual
    else:
        # A separate norm_out buffer exists, so input_tensor is handed over
        # as residual_out instead.
        residual_out = input_tensor

    flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        token_num=input_tensor.shape[0],
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        hidden_dim=input_tensor.shape[-1],
        workspace_ptrs=_FI_WORKSPACE_TENSOR,
        pattern_code=getattr(flashinfer_comm.AllReduceFusionPattern, pattern_name),
        allreduce_out=None,
        quant_out=quant_out,
        scale_out=scale_out,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=scale_factor,
        use_oneshot=use_oneshot,
        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )


def flashinfer_fused_allreduce_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    allreduce_params: "FlashInferFusedAllReduceParams",
    use_oneshot: bool,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm operation.

    Results are written into the supplied buffers; nothing is returned.

    Raises:
        RuntimeError: If FlashInfer is unavailable or the workspace has not
            been initialized.
    """
    _trtllm_fused_allreduce(
        "kARResidualRMSNorm",
        input_tensor,
        residual,
        rms_gamma,
        rms_eps,
        allreduce_params,
        use_oneshot,
        norm_out,
    )


def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    scale_factor: torch.Tensor,
    allreduce_params: "FlashInferFusedAllReduceParams",
    use_oneshot: bool = True,
    norm_out: torch.Tensor | None = None,
    quant_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP8 quantization.

    ``scale_factor`` and ``quant_out`` are forwarded to FlashInfer; results
    are written into the supplied buffers and nothing is returned.

    Raises:
        RuntimeError: If FlashInfer is unavailable or the workspace has not
            been initialized.
    """
    _trtllm_fused_allreduce(
        "kARResidualRMSNormFP8Quant",
        input_tensor,
        residual,
        rms_gamma,
        rms_eps,
        allreduce_params,
        use_oneshot,
        norm_out,
        quant_out=quant_out,
        scale_factor=scale_factor,
    )


def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    input_global_scale: torch.Tensor,
    allreduce_params: "FlashInferFusedAllReduceParams",
    quant_out: torch.Tensor,
    use_oneshot: bool,
    output_scale: torch.Tensor,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP4 quantization.

    ``input_global_scale`` is forwarded as ``scale_factor`` and
    ``output_scale`` as ``scale_out``; results are written into the supplied
    buffers and nothing is returned.

    Raises:
        RuntimeError: If FlashInfer is unavailable or the workspace has not
            been initialized.
    """
    _trtllm_fused_allreduce(
        "kARResidualRMSNormFP4Quant",
        input_tensor,
        residual,
        rms_gamma,
        rms_eps,
        allreduce_params,
        use_oneshot,
        norm_out,
        quant_out=quant_out,
        scale_out=output_scale,
        scale_factor=input_global_scale,
    )
class VllmFusedAllreduce:
    """Unfused reference path: allreduce, then RMSNorm, then optional quant.

    Holds an ``RMSNorm`` layer (eps 1e-6) and a static per-tensor ``QuantFP8``
    layer that the methods below chain after
    ``tensor_model_parallel_all_reduce``.
    """

    def __init__(self, hidden_dim, dtype):
        eps = 1e-6
        self.rms_eps = eps
        self.rms_norm = RMSNorm(hidden_dim, eps=eps, dtype=dtype)
        self.fp8_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def allreduce_rmsnorm(
        self, input_tensor: torch.Tensor, residual: torch.Tensor | None
    ):
        """Allreduce ``input_tensor`` and return ``rms_norm``'s result as-is."""
        reduced = tensor_model_parallel_all_reduce(input_tensor)
        return self.rms_norm(reduced, residual)

    def allreduce_rmsnorm_fp8_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        scale_factor: torch.Tensor,
    ):
        """Allreduce, RMSNorm, then FP8-quantize with ``scale_factor``.

        Returns:
            The quantizer output when ``residual`` is None, otherwise
            ``(quantizer output, residual_out)``.
        """
        normed = self.rms_norm(
            tensor_model_parallel_all_reduce(input_tensor), residual
        )
        if residual is None:
            return self.fp8_quant(normed, scale_factor)

        # With a residual, rms_norm's result is unpacked into two values.
        normed, residual_out = normed
        return self.fp8_quant(normed, scale_factor), residual_out

    def allreduce_rmsnorm_fp4_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        input_global_scale: torch.Tensor,
        quant_out: torch.Tensor,
        output_scale: torch.Tensor,
    ):
        """Allreduce, RMSNorm, then call ``SCALED_FP4_QUANT_OP``.

        The op is invoked with the caller's ``quant_out`` and ``output_scale``
        tensors, which are then returned.

        Returns:
            ``(quant_out, output_scale)`` when ``residual`` is None, otherwise
            ``(quant_out, residual_out, output_scale)``.
        """
        normed = self.rms_norm(
            tensor_model_parallel_all_reduce(input_tensor), residual
        )
        if residual is None:
            SCALED_FP4_QUANT_OP(quant_out, normed, output_scale, input_global_scale)
            return quant_out, output_scale

        # With a residual, rms_norm's result is unpacked into two values.
        normed, residual_out = normed
        SCALED_FP4_QUANT_OP(quant_out, normed, output_scale, input_global_scale)
        return quant_out, residual_out, output_scale
def create_test_tensors(
    num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
    """Allocate the input, scale and output tensors shared by all benchmarks.

    No device is passed to any allocation, so tensors land on torch's current
    default device.

    Args:
        num_tokens: Number of rows of the input tensor.
        hidden_dim: Number of columns of the input tensor.
        dtype: dtype of the input, residual and gamma tensors.
        use_residual: When True the residual is random and ``norm_out`` is
            None; when False the residual is all zeros and ``norm_out`` is an
            uninitialized tensor shaped like the input.

    Returns:
        Tuple ``(input_tensor, norm_out, residual, rms_gamma, scale_fp8,
        quant_out_fp8, scale_fp4, fp4_quant_out, fp4_output_scale)``.
    """
    activations = torch.randn(num_tokens, hidden_dim, dtype=dtype)

    if use_residual:
        residual = torch.randn_like(activations)
        norm_out = None
    else:
        residual = torch.zeros_like(activations)
        norm_out = torch.empty_like(activations)

    gamma = torch.ones(hidden_dim, dtype=dtype)

    # Quantization scales
    fp8_scale = torch.tensor(1.0, dtype=torch.float32)
    fp4_scale = torch.tensor(1.0, dtype=torch.float32)
    fp8_out = torch.empty_like(activations, dtype=FP8_DTYPE)

    # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
    fp4_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8)
    fp4_out_scale = torch.empty((128, 4), dtype=torch.int32)

    return (
        activations,
        norm_out,
        residual,
        gamma,
        fp8_scale,
        fp8_out,
        fp4_scale,
        fp4_out,
        fp4_out_scale,
    )
def benchmark_operation(
    operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
    """Benchmark a single operation using CUDA graphs.

    The operation is captured ``num_op_per_cudagraph`` (10) times into one
    CUDA graph, and that graph is replayed ``trials // 10`` times (at least
    once). The average is taken over the number of operations actually
    executed, so it stays correct when ``trials`` is not a multiple of 10 or
    is smaller than 10.

    Args:
        operation_func: Callable to benchmark.
        *args: Positional arguments forwarded to ``operation_func``.
        warmup: Number of eager calls before capture, and of graph replays
            before timing.
        trials: Requested number of timed operations; rounded down to a
            multiple of 10, with a minimum of 10.
        **kwargs: Keyword arguments forwarded to ``operation_func``.

    Returns:
        Average wall-clock time per operation, in milliseconds.
    """
    # Warmup before graph capture
    for _ in range(warmup):
        operation_func(*args, **kwargs)
    torch.cuda.synchronize()

    # Create CUDA graph
    graph = torch.cuda.CUDAGraph()
    num_op_per_cudagraph = 10

    # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    with graph_capture(device=device), torch.cuda.graph(graph):
        for _ in range(num_op_per_cudagraph):
            operation_func(*args, **kwargs)

    # Graph warmup
    torch.cuda.synchronize()
    for _ in range(warmup):
        graph.replay()

    # Always replay at least once so small `trials` values still measure
    # something instead of timing an empty loop.
    num_replays = max(1, trials // num_op_per_cudagraph)

    # Benchmark with CUDA graph
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_replays):
        graph.replay()
    torch.cuda.synchronize()
    end_time = time.perf_counter()

    # Normalize by the operations that really ran, not by the requested count.
    executed_ops = num_replays * num_op_per_cudagraph
    avg_time_ms = ((end_time - start_time) / executed_ops) * 1000
    return avg_time_ms
def _record_benchmark(
    results, key, failure_msg, operation, /, *args, compile_op=False, **kwargs
):
    """Time ``operation`` and store the result under ``results[key]``.

    When ``compile_op`` is True the operation is first wrapped with
    ``torch.compile(fullgraph=True, dynamic=False)``. Any exception is logged
    with ``failure_msg`` (a ``%s`` format string) and recorded as
    ``float("inf")`` so one failing variant does not abort the whole run.
    """
    try:
        if compile_op:
            operation = torch.compile(operation, fullgraph=True, dynamic=False)
        results[key] = benchmark_operation(operation, *args, **kwargs)
    except Exception as e:
        logger.error(failure_msg, e)
        results[key] = float("inf")


def _custom_ops_config(*custom_ops):
    """Return a context manager activating a VllmConfig with ``custom_ops``."""
    return set_current_vllm_config(
        VllmConfig(
            compilation_config=CompilationConfig(custom_ops=list(custom_ops))
        )
    )


def run_benchmarks(
    num_tokens: int,
    hidden_dim: int,
    dtype: torch.dtype,
    use_residual: bool,
    allreduce_params: FlashInferFusedAllReduceParams | None,
    quant_modes: set[str],
    no_oneshot: bool,
):
    """Run all benchmarks for given configuration.

    Args:
        num_tokens: Number of rows of the benchmark input.
        hidden_dim: Number of columns of the benchmark input.
        dtype: dtype of the benchmark input.
        use_residual: Forwarded to ``create_test_tensors``.
        allreduce_params: FlashInfer launch parameters; FlashInfer variants
            are skipped when this is None or FlashInfer is not importable.
        quant_modes: Any subset of ``{"none", "fp8", "fp4"}``. The FP4
            variants additionally require
            ``current_platform.has_device_capability(100)``.
        no_oneshot: When True only the two-shot FlashInfer variant is run;
            otherwise both one-shot and two-shot are run.

    Returns:
        Dict mapping a benchmark name to its average time in milliseconds,
        or to ``float("inf")`` when that benchmark raised.
    """
    (
        input_tensor,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual)

    rms_eps = 1e-6
    results = {}
    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
    use_oneshot_options = [False] if no_oneshot else [True, False]
    flashinfer_enabled = flashinfer_comm is not None and allreduce_params is not None

    if "none" in quant_modes:
        # Standard AllReduce + RMSNorm
        for custom_op in ["-rms_norm", "+rms_norm"]:
            suffix = "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
            with _custom_ops_config(custom_op):
                _record_benchmark(
                    results,
                    f"standard_allreduce_{suffix}",
                    "Standard AllReduce+RMSNorm failed: %s",
                    vllm_fused_allreduce.allreduce_rmsnorm,
                    input_tensor,
                    residual=residual,
                )

        # Standard AllReduce + RMSNorm Native Compiled
        with _custom_ops_config("-rms_norm"):
            _record_benchmark(
                results,
                "standard_allreduce_rmsnorm_native_compiled",
                "Standard AllReduce+RMSNorm Native Compiled failed: %s",
                vllm_fused_allreduce.allreduce_rmsnorm,
                input_tensor,
                compile_op=True,
                residual=residual,
            )

        # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
        if flashinfer_enabled:
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                _record_benchmark(
                    results,
                    f"flashinfer_fused_allreduce_rmsnorm{suffix}",
                    "FlashInfer Fused AllReduce+RMSNorm failed: %s",
                    flashinfer_fused_allreduce_rmsnorm,
                    input_tensor,
                    residual=residual,
                    norm_out=norm_out,
                    rms_gamma=rms_gamma,
                    rms_eps=rms_eps,
                    allreduce_params=allreduce_params,
                    use_oneshot=use_oneshot,
                )

    if "fp8" in quant_modes:
        # Standard AllReduce + RMSNorm + FP8 Quant
        for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
            rms_suffix = (
                "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
            )
            for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
                # Build the suffix from scratch each iteration. Appending to a
                # shared variable would carry the previous quant suffix over
                # and mislabel the second variant.
                quant_suffix = (
                    "_custom_quant_fp8"
                    if "+" in quant_fp8_custom_op
                    else "_native_quant_fp8"
                )
                suffix = rms_suffix + quant_suffix
                with _custom_ops_config(rms_norm_custom_op, quant_fp8_custom_op):
                    _record_benchmark(
                        results,
                        f"standard_allreduce{suffix}",
                        "Standard AllReduce+RMSNorm+FP8 failed: %s",
                        vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                        input_tensor,
                        residual=residual,
                        scale_factor=scale_fp8,
                    )

        # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
        with _custom_ops_config("-rms_norm", "-quant_fp8"):
            _record_benchmark(
                results,
                "standard_allreduce_rmsnorm_fp8_quant_native_compiled",
                "Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s",
                vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
                input_tensor,
                compile_op=True,
                residual=residual,
                scale_factor=scale_fp8,
            )

        # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot/Twoshot
        if flashinfer_enabled:
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                _record_benchmark(
                    results,
                    f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}",
                    "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
                    flashinfer_fused_allreduce_rmsnorm_fp8_quant,
                    input_tensor,
                    norm_out=norm_out,
                    residual=residual,
                    rms_gamma=rms_gamma,
                    rms_eps=rms_eps,
                    scale_factor=scale_fp8,
                    quant_out=quant_out_fp8,
                    allreduce_params=allreduce_params,
                    use_oneshot=use_oneshot,
                )

    if "fp4" in quant_modes and current_platform.has_device_capability(100):
        # Standard AllReduce + RMSNorm + FP4 Quant
        for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
            suffix = (
                "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
            )
            with _custom_ops_config(rms_norm_custom_op):
                _record_benchmark(
                    results,
                    f"standard_allreduce_{suffix}_fp4_quant",
                    "Standard AllReduce+RMSNorm+FP4 failed: %s",
                    vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                    input_tensor,
                    residual=residual,
                    input_global_scale=scale_fp4,
                    quant_out=fp4_quant_out,
                    output_scale=fp4_output_scale,
                )

        # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
        with _custom_ops_config("-rms_norm"):
            _record_benchmark(
                results,
                "standard_allreduce_rmsnorm_fp4_quant_native_compiled",
                "Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s",
                vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
                input_tensor,
                compile_op=True,
                residual=residual,
                quant_out=fp4_quant_out,
                input_global_scale=scale_fp4,
                output_scale=fp4_output_scale,
            )

        # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot/Twoshot.
        # use_oneshot_options always contains False, so the two-shot variant
        # is covered here and needs no separate run afterwards.
        if flashinfer_enabled:
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                _record_benchmark(
                    results,
                    f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}",
                    "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
                    flashinfer_fused_allreduce_rmsnorm_fp4_quant,
                    input_tensor,
                    residual=residual,
                    norm_out=norm_out,
                    rms_gamma=rms_gamma,
                    rms_eps=rms_eps,
                    input_global_scale=scale_fp4,
                    allreduce_params=allreduce_params,
                    quant_out=fp4_quant_out,
                    output_scale=fp4_output_scale,
                    use_oneshot=use_oneshot,
                )

    return results
def prepare_results_with_speedups(results_dict):
    """Prepare results with speedup calculations based on dynamic baseline selection.

    Operations are grouped by quantization category ("fp8", "fp4" or none,
    detected from the operation name). Within a category the baseline is the
    fastest successful operation whose name starts with ``standard_``.

    Args:
        results_dict: Mapping of operation name to time in milliseconds, with
            ``float("inf")`` marking a failed benchmark.

    Returns:
        A list with one dict per operation, in input order, holding
        ``operation``, ``time_ms``, ``time_str`` and ``speedup_str``.
        ``speedup_str`` is one of:

        * ``"FAILED"`` when the operation failed;
        * ``"baseline"`` when the operation is its category's baseline;
        * ``"<ratio>x"`` (baseline time / operation time) otherwise;
        * ``"N/A"`` when no baseline exists, a time is not positive, or the
          name is not a recognised benchmark name.
    """
    failed = float("inf")

    def quant_category(op_name):
        """Classify an operation by the quantization named in its key."""
        # Match on "fp8"/"fp4" so both "..._fp8_quant" and "..._quant_fp8"
        # style names land in the same category.
        if "fp8" in op_name:
            return "fp8"
        if "fp4" in op_name:
            return "fp4"
        return "none"

    def get_fastest_baseline(op_name, results_dict):
        """Get the fastest successful standard_* result in op_name's category."""
        category = quant_category(op_name)
        fastest_time = failed
        fastest_baseline = None
        for candidate, candidate_time in results_dict.items():
            if (
                candidate.startswith("standard_")
                and quant_category(candidate) == category
                # Strictly less than inf also filters out failed runs.
                and candidate_time < fastest_time
            ):
                fastest_time = candidate_time
                fastest_baseline = candidate
        return fastest_baseline

    # Baseline lookup is the same for every op in a category; compute it once.
    baseline_by_category = {}
    prepared_results = []

    for op_name, time_ms in results_dict.items():
        if time_ms == failed:
            time_str = "FAILED"
            speedup_str = "FAILED"
        else:
            time_str = f"{time_ms:.3f}"
            speedup_str = "N/A"

            is_benchmark_op = op_name.startswith(
                ("flashinfer_", "standard_")
            ) or op_name.endswith("_native_compiled")
            if is_benchmark_op:
                category = quant_category(op_name)
                if category not in baseline_by_category:
                    baseline_by_category[category] = get_fastest_baseline(
                        op_name, results_dict
                    )
                baseline_op = baseline_by_category[category]

                if baseline_op == op_name:
                    speedup_str = "baseline"
                elif baseline_op is not None:
                    baseline_time = results_dict[baseline_op]
                    # Guard both operands so a zero time cannot divide by zero.
                    if baseline_time > 0 and time_ms > 0:
                        speedup_str = f"{baseline_time / time_ms:.2f}x"

        prepared_results.append(
            {
                "operation": op_name,
                "time_ms": time_ms,
                "time_str": time_str,
                "speedup_str": speedup_str,
            }
        )

    return prepared_results
def print_results(
    results_dict,
    num_tokens,
    hidden_dim,
    dtype,
    use_residual,
    quant_modes,
    input_size_mb,
):
    """Print one configuration's benchmark results as a fixed-width table.

    A banner describing the configuration is written to stdout first,
    followed by one row per operation with its time in milliseconds and the
    speedup label computed by ``prepare_results_with_speedups``.
    """
    heavy_rule = "=" * 80
    residual_label = "yes" if use_residual else "no"
    modes_label = ",".join(sorted(quant_modes))

    # Banner describing the benchmarked configuration.
    print(f"\n{heavy_rule}")
    print(
        f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} "
        f"(input size: {input_size_mb:.2f} MB)"
    )
    print(f"dtype={dtype}, residual={residual_label}, quant_modes={modes_label}")
    print(heavy_rule)

    # Column headers.
    print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
    print("-" * 80)

    for row in prepare_results_with_speedups(results_dict):
        # Failed operations carry an infinite time and show their status text.
        if row["time_ms"] == float("inf"):
            shown_time = row["time_str"]
        else:
            shown_time = f"{row['time_ms']:.3f}"
        print(f"{row['operation']:<50} {shown_time:<12} {row['speedup_str']:<10}")
def format_results_markdown(
    all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
    """Render every collected benchmark result as one Markdown document.

    Args:
        all_results: One dict per benchmarked configuration with the keys
            ``num_tokens``, ``dtype``, ``use_residual``, ``results``,
            ``input_size_mb`` and ``quant_modes``.
        world_size: Number of participating ranks, shown in the header.
        args: Parsed CLI arguments; ``hidden_dim``, ``warmup`` and ``trials``
            are shown in the header.

    Returns:
        The Markdown text: a header section followed by one table per
        configuration.
    """
    # Header metadata. The two trailing spaces are Markdown hard line breaks.
    lines: list[str] = [
        "# FlashInfer Fused Collective Operations Benchmark Results",
        "",
        f"**World Size:** {world_size}  ",
        f"**Hidden Dimension:** {args.hidden_dim}  ",
        f"**Warmup Iterations:** {args.warmup}  ",
        f"**Benchmark Trials:** {args.trials}  ",
    ]
    # The quantization modes are taken from the first entry.
    modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A"
    lines += [f"**Quantization Modes:** {modes}  ", "", "---", ""]

    entry_keys = ("num_tokens", "dtype", "use_residual", "results", "input_size_mb")
    for entry in all_results:
        num_tokens, dtype, use_residual, results_dict, input_size_mb = (
            entry[key] for key in entry_keys
        )
        residual_str = "with residual" if use_residual else "no residual"

        lines.append(
            f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}"
        )
        lines.append(f"**Input Size:** {input_size_mb:.2f} MB")
        lines.append("")

        # One table row per operation, with a human-friendly operation name.
        table = pd.DataFrame(
            [
                {
                    "Operation": item["operation"].replace("_", " ").title(),
                    "Time (ms)": item["time_str"],
                    "Speedup": item["speedup_str"],
                }
                for item in prepare_results_with_speedups(results_dict)
            ]
        )
        lines.append("No results." if table.empty else table.to_markdown(index=False))
        lines.append("")

    return "\n".join(lines)
def save_results_to_file(
    all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
    """Append the Markdown report to ``args.output_file`` (rank 0 only).

    Other ranks return immediately. An empty result list is logged as a
    warning and nothing is written. Any failure while formatting or writing
    is logged as an error instead of being raised.
    """
    # Only rank 0 is allowed to touch the output file.
    if rank != 0:
        return

    if not all_results:
        logger.warning("No results to save")
        return

    destination = args.output_file
    try:
        # Format first so a formatting failure never creates the file.
        report = format_results_markdown(all_results, world_size, args)
        with open(destination, "a") as report_file:
            report_file.write(report)
    except Exception as err:
        logger.error("Failed to save results to file: %s", err)
def main():
    """Parse CLI options, set up the distributed environment and run the benchmarks.

    The script must be launched through torchrun so that ``RANK`` and
    ``WORLD_SIZE`` are present in the environment. Every combination of
    ``--num-tokens`` and ``--dtypes`` is benchmarked; rank 0 prints a table
    per configuration and, when ``--output-file`` is given, appends a
    Markdown report to that file.

    Raises:
        RuntimeError: If ``RANK`` or ``WORLD_SIZE`` is missing from the
            environment.
        ValueError: If the world size is not greater than 1, if
            ``--quant-modes`` contains an unknown entry, or if FlashInfer is
            available but ``_FI_MAX_SIZES`` has no entry for the world size.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark fused collective operations"
    )
    parser.add_argument(
        "--num-tokens",
        type=int,
        nargs="+",
        default=[128, 512, 1024, 2048],
        help="Numbers of tokens to test",
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=8192, help="Hidden dimension size"
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=["bfloat16"],
        choices=["float16", "bfloat16", "float32"],
        help="Data types to test",
    )
    parser.add_argument(
        "--no-residual",
        action="store_true",
        help="Skip residual connection tests",
    )
    parser.add_argument(
        "--quant-modes",
        type=str,
        default="none,fp8,fp4",
        help=(
            "Comma-separated quantization modes to run: none, fp8, fp4. "
            "Default: none,fp8,fp4"
        ),
    )
    parser.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--trials", type=int, default=20, help="Number of benchmark trials"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        # No default path exists: results are saved only when this is given.
        help=(
            "Output file path for markdown results "
            "(results are only saved to a file when this option is given)"
        ),
    )
    parser.add_argument(
        "--no-oneshot",
        action="store_true",
        help="Skip oneshot benchmarks",
    )

    args = parser.parse_args()

    # Check if running with torchrun (required for collective operations)
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        raise RuntimeError(
            "Must run with torchrun for distributed benchmarking. "
            "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
        )

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Validate all inputs before touching CUDA or the process group, so a bad
    # invocation fails fast without initializing distributed state.
    # World size must be > 1 for collective operations.
    if world_size <= 1:
        raise ValueError(
            "World size must be > 1 for collective operations benchmarking. "
            f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
        )

    # Parse quantization modes
    valid_quant_modes = {"none", "fp8", "fp4"}
    raw_modes = [
        m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip()
    ]
    # An empty list falls back to every valid mode.
    quant_modes = set(raw_modes) if raw_modes else set(valid_quant_modes)
    invalid = sorted(quant_modes - valid_quant_modes)
    if invalid:
        raise ValueError(
            f"Invalid --quant-modes entries: {','.join(invalid)}. "
            f"Valid options are: {','.join(sorted(valid_quant_modes))}."
        )

    # Initialize distributed environment
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    if rank == 0:
        logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
        logger.info("Quantization modes: %s", ",".join(sorted(quant_modes)))
        if flashinfer_comm is not None:
            logger.info(
                "FlashInfer available - will benchmark fused operations",
            )
        else:
            logger.info(
                "FlashInfer not available - only benchmarking standard operations"
            )

    # Convert dtype strings to torch dtypes
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtypes = [dtype_map[dt] for dt in args.dtypes]

    # Test configurations: a single residual setting, chosen by --no-residual.
    residual_options = [not args.no_residual]

    configs = list(itertools.product(args.num_tokens, dtypes, residual_options))

    # Setup FlashInfer workspace if available
    ipc_handles = None
    allreduce_params = None

    if flashinfer_comm is not None:
        fi_max_size = _FI_MAX_SIZES.get(world_size)
        if fi_max_size is None:
            # Without this check the floor division below would fail with an
            # opaque "NoneType // int" TypeError.
            raise ValueError(
                "No FlashInfer maximum workspace size is defined for "
                f"world_size={world_size}."
            )
        # Token budget of the workspace for the configured hidden dimension.
        max_num_token = fi_max_size // (args.hidden_dim * world_size * 2)

        ipc_handles, workspace_tensor = setup_flashinfer_workspace(
            world_size, rank, args.hidden_dim, max_num_token
        )

        if workspace_tensor is not None:
            allreduce_params = FlashInferFusedAllReduceParams(
                rank=rank,
                world_size=world_size,
                max_token_num=max_num_token,
            )

    # Collect all results for markdown export
    all_results = []

    try:
        # Run benchmarks
        for num_tokens, dtype, use_residual in configs:
            if rank == 0:
                logger.info(
                    "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s",
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                )

            results = run_benchmarks(
                num_tokens,
                args.hidden_dim,
                dtype,
                use_residual,
                allreduce_params,
                quant_modes=quant_modes,
                no_oneshot=args.no_oneshot,
            )

            # Store results for markdown export
            if rank == 0:
                # Calculate input size in MB
                input_size_mb = (
                    num_tokens * args.hidden_dim * torch.finfo(dtype).bits
                ) / (8 * 1024 * 1024)
                all_results.append(
                    {
                        "num_tokens": num_tokens,
                        "hidden_dim": args.hidden_dim,
                        "dtype": str(dtype).replace("torch.", ""),
                        "use_residual": use_residual,
                        "quant_modes": sorted(quant_modes),
                        "input_size_mb": input_size_mb,
                        "results": results,
                    }
                )

                print_results(
                    results,
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                    quant_modes,
                    input_size_mb,
                )

        # Save results to markdown file (once, after every configuration ran)
        if args.output_file and rank == 0:
            save_results_to_file(all_results, world_size, args, rank)

    finally:
        # Cleanup
        if ipc_handles is not None:
            cleanup_flashinfer_workspace(ipc_handles)

        dist.barrier()
# Script entry point: run the benchmark only when this file is executed
# directly (e.g. through torchrun), not when it is imported.
if __name__ == "__main__":
    main()
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
View file @
41199996
...
@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
...
@@ -13,11 +13,11 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts
,
fused_experts
,
fused_topk
,
fused_topk
,
)
)
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
DEFAULT_MODELS
=
[
"
nm-testing
/Mixtral-8x7B-Instruct-v0.1"
,
"
mistralai
/Mixtral-8x7B-Instruct-v0.1"
,
"
nm-testing/d
eep
s
eek
v
2-
l
ite"
,
"
deepseek-ai/D
eep
S
eek
-V
2-
L
ite"
,
"ibm-granite/granite-3.0-1b-a400m"
,
"ibm-granite/granite-3.0-1b-a400m"
,
"ibm-granite/granite-3.0-3b-a800m"
,
"ibm-granite/granite-3.0-3b-a800m"
,
]
]
...
...
benchmarks/kernels/benchmark_layernorm.py
View file @
41199996
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
benchmarks/kernels/benchmark_lora.py
View file @
41199996
...
@@ -6,11 +6,12 @@ import copy
...
@@ -6,11 +6,12 @@ import copy
import
json
import
json
import
pickle
import
pickle
import
time
import
time
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
itertools
import
product
from
itertools
import
product
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
from
typing
import
Any
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
...
@@ -18,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
weight_shapes
import
WEIGHT_SHAPES
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.lora.ops.triton_ops.utils
import
get_lora_op_configs
from
vllm.triton_utils
import
HAS_TRITON
,
triton
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
,
lora_expand
,
lora_shrink
from
vllm.lora.ops.triton_ops
import
(
## added fused_moe_lora
LoRAKernelMeta
,
fused_moe_lora_expand
,
fused_moe_lora_shrink
,
lora_expand
,
lora_shrink
,
)
from
vllm.lora.ops.triton_ops.fused_moe_lora_op
import
(
_LORA_PTR_DICT
,
## added _LORA_PTR_DICT for fused_moe_lora
)
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.math_utils
import
round_up
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_TP_SIZES
=
[
1
]
DEFAULT_TP_SIZES
=
[
1
]
...
@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
...
@@ -58,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS
=
[
False
,
True
]
DEFAULT_SORT_BY_LORA_IDS
=
[
False
,
True
]
DEFAULT_SEQ_LENGTHS
=
[
1
]
DEFAULT_SEQ_LENGTHS
=
[
1
]
DEFAULT_EXPAND_FN_ADD_INPUTS
=
[
True
,
False
]
DEFAULT_EXPAND_FN_ADD_INPUTS
=
[
True
,
False
]
DEFAULT_TOP_K_NUMS
=
[
1
]
# Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS
=
[
8
]
# Added for MoE LoRA num_experts
# Utilities
# Utilities
...
@@ -158,7 +172,7 @@ def ref_group_gemm(
...
@@ -158,7 +172,7 @@ def ref_group_gemm(
seq_lens_cpu
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
prompt_lora_mapping_cpu
:
torch
.
Tensor
,
prompt_lora_mapping_cpu
:
torch
.
Tensor
,
scaling
:
float
,
scaling
:
float
,
add_inputs
:
Optional
[
bool
]
,
add_inputs
:
bool
|
None
,
):
):
"""
"""
Torch group gemm reference implementation to test correctness of
Torch group gemm reference implementation to test correctness of
...
@@ -190,6 +204,11 @@ class OpType(Enum):
...
@@ -190,6 +204,11 @@ class OpType(Enum):
LORA_SHRINK
=
auto
()
LORA_SHRINK
=
auto
()
LORA_EXPAND
=
auto
()
LORA_EXPAND
=
auto
()
## Adding support for fused moe lora
FUSED_MOE_LORA_GATE_UP_SHRINK
=
auto
()
## Gate/Up projection variant with shrink
FUSED_MOE_LORA_GATE_UP_EXPAND
=
auto
()
## Gate/Up projection variant with expand
FUSED_MOE_LORA_DOWN_SHRINK
=
auto
()
## Down projection variant with shrink
FUSED_MOE_LORA_DOWN_EXPAND
=
auto
()
## Down projection variant with expand
@
staticmethod
@
staticmethod
def
from_str
(
s
:
str
)
->
"OpType"
:
def
from_str
(
s
:
str
)
->
"OpType"
:
...
@@ -197,6 +216,15 @@ class OpType(Enum):
...
@@ -197,6 +216,15 @@ class OpType(Enum):
return
OpType
.
LORA_SHRINK
return
OpType
.
LORA_SHRINK
if
s
.
lower
()
==
"lora_expand"
:
if
s
.
lower
()
==
"lora_expand"
:
return
OpType
.
LORA_EXPAND
return
OpType
.
LORA_EXPAND
# Adding support for fused moe lora, both in gate_up and down
if
s
.
lower
()
==
"fused_moe_lora_gate_up_shrink"
:
## Gate/Up variant with shrink
return
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
if
s
.
lower
()
==
"fused_moe_lora_gate_up_expand"
:
## Gate/Up variant with expand
return
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
if
s
.
lower
()
==
"fused_moe_lora_down_shrink"
:
## Down variant with shrink
return
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
if
s
.
lower
()
==
"fused_moe_lora_down_expand"
:
## Down variant with expand
return
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
raise
ValueError
(
f
"Unrecognized str
{
s
}
to convert to OpType"
)
raise
ValueError
(
f
"Unrecognized str
{
s
}
to convert to OpType"
)
def
is_shrink_fn
(
self
)
->
bool
:
def
is_shrink_fn
(
self
)
->
bool
:
...
@@ -205,19 +233,56 @@ class OpType(Enum):
...
@@ -205,19 +233,56 @@ class OpType(Enum):
def
is_expand_fn
(
self
)
->
bool
:
def
is_expand_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
LORA_EXPAND
]
return
self
in
[
OpType
.
LORA_EXPAND
]
def is_fused_moe_lora_fn(self) -> bool:
    """Return True when this op is any of the four fused MoE LoRA variants."""
    fused_moe_lora_ops = (
        OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
        OpType.FUSED_MOE_LORA_DOWN_SHRINK,
        OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
        OpType.FUSED_MOE_LORA_DOWN_EXPAND,
    )
    return self in fused_moe_lora_ops
def is_fused_moe_lora_gate_up_fn(self) -> bool:
    """Return True for the gate/up projection variants (shrink or expand)."""
    gate_up_ops = (
        OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
        OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
    )
    return self in gate_up_ops
def is_fused_moe_lora_down_fn(self) -> bool:
    """Return True for the down projection variants (shrink or expand)."""
    down_ops = (
        OpType.FUSED_MOE_LORA_DOWN_SHRINK,
        OpType.FUSED_MOE_LORA_DOWN_EXPAND,
    )
    return self in down_ops
def is_fused_moe_lora_shrink_fn(self) -> bool:
    """Return True for the fused MoE LoRA shrink variants (gate/up or down)."""
    shrink_ops = (
        OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
        OpType.FUSED_MOE_LORA_DOWN_SHRINK,
    )
    return self in shrink_ops
def is_fused_moe_lora_expand_fn(self) -> bool:
    """Return True for the fused MoE LoRA expand variants (gate/up or down)."""
    expand_ops = (
        OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
        OpType.FUSED_MOE_LORA_DOWN_EXPAND,
    )
    return self in expand_ops
def
num_slices
(
self
)
->
list
[
int
]:
def
num_slices
(
self
)
->
list
[
int
]:
if
self
.
is_fused_moe_lora_gate_up_fn
():
return
[
2
]
elif
self
.
is_fused_moe_lora_down_fn
():
return
[
1
]
return
[
1
,
2
,
3
]
return
[
1
,
2
,
3
]
def
mkn
(
def
mkn
(
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
)
->
tuple
[
int
,
int
,
int
]:
)
->
tuple
[
int
,
int
,
int
]:
num_tokens
=
batch_size
*
seq_length
num_tokens
=
batch_size
*
seq_length
if
self
.
is_shrink_fn
():
if
self
.
is_shrink_fn
()
or
self
.
is_fused_moe_lora_fn
()
:
m
=
num_tokens
m
=
num_tokens
k
=
hidden_size
k
=
hidden_size
n
=
lora_rank
n
=
lora_rank
else
:
elif
self
.
is_expand_fn
():
assert
self
.
is_expand_fn
()
m
=
num_tokens
m
=
num_tokens
k
=
lora_rank
k
=
lora_rank
n
=
hidden_size
n
=
hidden_size
...
@@ -231,9 +296,36 @@ class OpType(Enum):
...
@@ -231,9 +296,36 @@ class OpType(Enum):
"""
"""
if
self
.
is_shrink_fn
():
if
self
.
is_shrink_fn
():
return
op_dtype
,
op_dtype
,
torch
.
float32
return
op_dtype
,
op_dtype
,
torch
.
float32
else
:
elif
self
.
is_expand_fn
():
assert
self
.
is_expand_fn
()
return
torch
.
float32
,
op_dtype
,
op_dtype
return
torch
.
float32
,
op_dtype
,
op_dtype
else
:
assert
self
.
is_fused_moe_lora_fn
()
return
op_dtype
,
op_dtype
,
op_dtype
def matmul_shapes_fused_moe_lora(
    self,
    m: int,
    n: int,
    k: int,
    num_loras: int,
    num_slices: int,
    top_k_num: int,
    num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    """Return the tensor shapes used by a fused MoE LoRA op.

    Returns:
        ``(input_shape, weight_shape, output_shape)``.

        * Shrink ops: input ``(m, n)`` (``(m * top_k_num, n)`` for
          ``FUSED_MOE_LORA_DOWN_SHRINK``), weight
          ``(num_loras, num_experts, k, n)``, output
          ``(num_slices, m, top_k_num, k)``.
        * Expand ops: input ``(num_slices, m, top_k_num, k)``, weight
          ``(num_loras, num_experts, n, k)``, output
          ``(m, top_k_num, n * num_slices)``.

    Raises:
        AssertionError: If this op is neither a fused MoE LoRA shrink nor
            expand op.
    """
    if self.is_fused_moe_lora_shrink_fn():
        # Only the down-shrink variant takes one input row per (token, top-k)
        # pair; the gate/up variant takes one row per token.
        if self is OpType.FUSED_MOE_LORA_DOWN_SHRINK:
            input_shape = (m * top_k_num, n)
        else:
            input_shape = (m, n)
        output_shape = (num_slices, m, top_k_num, k)
        weight_shape = (num_loras, num_experts, k, n)
    else:
        assert self.is_fused_moe_lora_expand_fn()
        input_shape = (num_slices, m, top_k_num, k)
        output_shape = (m, top_k_num, n * num_slices)
        weight_shape = (num_loras, num_experts, n, k)
    return (input_shape, weight_shape, output_shape)
def
matmul_shapes
(
def
matmul_shapes
(
self
,
self
,
...
@@ -243,6 +335,8 @@ class OpType(Enum):
...
@@ -243,6 +335,8 @@ class OpType(Enum):
lora_rank
:
int
,
lora_rank
:
int
,
num_loras
:
int
,
num_loras
:
int
,
num_slices
:
int
,
num_slices
:
int
,
top_k_num
:
int
|
None
=
None
,
num_experts
:
int
|
None
=
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
"""
"""
Given num_slices, return the shapes of the A, B, and C matrices
Given num_slices, return the shapes of the A, B, and C matrices
...
@@ -257,6 +351,16 @@ class OpType(Enum):
...
@@ -257,6 +351,16 @@ class OpType(Enum):
if
self
in
[
OpType
.
LORA_EXPAND
]:
if
self
in
[
OpType
.
LORA_EXPAND
]:
# LoRA expand kernels support num_slices inherently in the kernel
# LoRA expand kernels support num_slices inherently in the kernel
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
if
self
.
is_fused_moe_lora_fn
():
return
self
.
matmul_shapes_fused_moe_lora
(
m
,
k
,
n
,
num_loras
,
num_slices
,
top_k_num
,
num_experts
,
)
raise
ValueError
(
f
"Unrecognized op_type
{
self
}
"
)
raise
ValueError
(
f
"Unrecognized op_type
{
self
}
"
)
def
bench_fn
(
self
)
->
Callable
:
def
bench_fn
(
self
)
->
Callable
:
...
@@ -264,6 +368,16 @@ class OpType(Enum):
...
@@ -264,6 +368,16 @@ class OpType(Enum):
return
lora_shrink
return
lora_shrink
if
self
==
OpType
.
LORA_EXPAND
:
if
self
==
OpType
.
LORA_EXPAND
:
return
lora_expand
return
lora_expand
if
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_SHRINK
,
OpType
.
FUSED_MOE_LORA_DOWN_SHRINK
,
]:
return
fused_moe_lora_shrink
if
self
in
[
OpType
.
FUSED_MOE_LORA_GATE_UP_EXPAND
,
OpType
.
FUSED_MOE_LORA_DOWN_EXPAND
,
]:
return
fused_moe_lora_expand
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
...
@@ -316,8 +430,10 @@ class BenchmarkContext:
...
@@ -316,8 +430,10 @@ class BenchmarkContext:
lora_rank
:
int
lora_rank
:
int
sort_by_lora_id
:
bool
sort_by_lora_id
:
bool
dtype
:
torch
.
dtype
dtype
:
torch
.
dtype
seq_length
:
Optional
[
int
]
=
None
seq_length
:
int
|
None
=
None
num_slices
:
Optional
[
int
]
=
None
# num_slices for slice based ops
num_experts
:
int
|
None
=
None
# num_experts for MoE based ops
top_k_num
:
int
|
None
=
None
# top_k for MoE based ops
num_slices
:
int
|
None
=
None
# num_slices for slice based ops
def
with_seq_length
(
self
,
seq_length
:
int
)
->
"BenchmarkContext"
:
def
with_seq_length
(
self
,
seq_length
:
int
)
->
"BenchmarkContext"
:
ctx
=
copy
.
copy
(
self
)
ctx
=
copy
.
copy
(
self
)
...
@@ -372,6 +488,11 @@ class BenchmarkTensors:
...
@@ -372,6 +488,11 @@ class BenchmarkTensors:
f
"
{
dtype_to_str
(
self
.
output
.
dtype
)
}
"
f
"
{
dtype_to_str
(
self
.
output
.
dtype
)
}
"
)
)
def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType):
    """Return the token count to use for ``op_type``.

    ``size`` is multiplied by ``top_k_num`` for
    ``FUSED_MOE_LORA_DOWN_SHRINK``; every other op type gets ``size`` back
    unchanged.
    """
    if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]:
        return size * top_k_num
    return size
@
staticmethod
@
staticmethod
def
make
(
def
make
(
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
device
:
str
=
"cuda"
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
device
:
str
=
"cuda"
...
@@ -384,6 +505,8 @@ class BenchmarkTensors:
...
@@ -384,6 +505,8 @@ class BenchmarkTensors:
ctx
.
lora_rank
,
ctx
.
lora_rank
,
ctx
.
num_loras
,
ctx
.
num_loras
,
ctx
.
num_slices
,
ctx
.
num_slices
,
ctx
.
top_k_num
,
ctx
.
num_experts
,
)
)
a_type
,
b_type
,
c_type
=
op_type
.
matmul_dtypes
(
ctx
.
dtype
)
a_type
,
b_type
,
c_type
=
op_type
.
matmul_dtypes
(
ctx
.
dtype
)
input_tensor
,
lora_weights
,
output_tensor
=
make_rand_tensors
(
input_tensor
,
lora_weights
,
output_tensor
=
make_rand_tensors
(
...
@@ -431,17 +554,27 @@ class BenchmarkTensors:
...
@@ -431,17 +554,27 @@ class BenchmarkTensors:
prompt_lora_indices_tensor
,
prompt_lora_indices_tensor
,
)
)
def
sanity_check
(
self
)
->
None
:
def
sanity_check
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
None
:
"""
"""
Fails asserts when non-conformality is detected.
Fails asserts when non-conformality is detected.
"""
"""
num_tokens
=
self
.
input
.
shape
[
-
2
]
num_tokens
=
(
self
.
input
.
shape
[
1
]
if
op_type
.
is_fused_moe_lora_expand_fn
()
else
self
.
input
.
shape
[
-
2
]
)
# check metadata tensors
# check metadata tensors
assert
torch
.
sum
(
self
.
seq_lens
)
==
num_tokens
## In down shrink case, each token is repeated top_k_num times
assert
num_tokens
==
self
.
get_num_tokens
(
torch
.
sum
(
self
.
seq_lens
),
ctx
.
top_k_num
,
op_type
),
f
"Expected
{
num_tokens
}
tokens, but got
{
torch
.
sum
(
self
.
seq_lens
)
}
"
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
# assert self.seq_start_loc.shape[0] == num_seqs
# assert self.seq_start_loc.shape[0] == num_seqs
## In down shrink case, each prompt corresponds to top_k_num sequences
assert
self
.
prompt_lora_mapping
.
shape
[
0
]
==
num_seqs
assert
self
.
prompt_lora_mapping
.
shape
[
0
]
==
num_seqs
assert
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
]
==
num_tokens
assert
self
.
get_num_tokens
(
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
],
ctx
.
top_k_num
,
op_type
)
def
to_device
(
self
,
device
:
str
):
def
to_device
(
self
,
device
:
str
):
"""
"""
...
@@ -470,21 +603,111 @@ class BenchmarkTensors:
...
@@ -470,21 +603,111 @@ class BenchmarkTensors:
to_device
(
field
)
if
field_name
!=
"no_lora_flag_cpu"
else
field
,
to_device
(
field
)
if
field_name
!=
"no_lora_flag_cpu"
else
field
,
)
)
def
metadata
(
self
)
->
tuple
[
int
,
int
,
int
]:
def
metadata
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
tuple
[
int
,
int
,
int
]:
"""
"""
Return num_seqs, num_tokens and max_seq_len
Return num_seqs, num_tokens and max_seq_len
"""
"""
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_tokens
=
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
]
num_tokens
=
self
.
get_num_tokens
(
self
.
lora_kernel_meta
.
token_lora_mapping
.
shape
[
0
],
ctx
.
top_k_num
,
op_type
)
max_seq_len
=
torch
.
max
(
self
.
seq_lens
).
item
()
max_seq_len
=
torch
.
max
(
self
.
seq_lens
).
item
()
num_slices
=
len
(
self
.
lora_weights_lst
)
num_slices
=
len
(
self
.
lora_weights_lst
)
return
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
return
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
def
as_lora_shrink_kwargs
(
self
)
->
dict
[
str
,
Any
]:
def
fused_moe_lora_data_prepare
(
self
.
sanity_check
()
self
,
block_size
:
int
,
token_lora_mapping
:
torch
.
Tensor
,
ctx
:
BenchmarkContext
,
):
def
moe_lora_align_block_size
(
topk_ids
:
torch
.
Tensor
,
token_lora_mapping
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
max_loras
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
pad_sorted_ids
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
sorted_ids
=
torch
.
empty
(
(
max_loras
*
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids
=
torch
.
empty
(
(
max_loras
*
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
num_tokens_post_pad
=
torch
.
empty
(
(
max_loras
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
ops
.
moe_lora_align_block_size
(
topk_ids
,
token_lora_mapping
,
num_experts
,
block_size
,
max_loras
,
max_num_tokens_padded
,
max_num_m_blocks
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
num_tokens
=
ctx
.
batch_size
curr_topk_ids
=
torch
.
randint
(
0
,
ctx
.
num_experts
,
(
num_tokens
,
ctx
.
top_k_num
),
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
topk_weights
=
torch
.
randint
(
0
,
ctx
.
num_experts
,
(
num_tokens
,
ctx
.
top_k_num
),
device
=
"cuda"
,
dtype
=
torch
.
int32
,
)
(
sorted_token_ids_lora
,
expert_ids_lora
,
num_tokens_post_padded_lora
)
=
(
moe_lora_align_block_size
(
topk_ids
=
curr_topk_ids
,
token_lora_mapping
=
token_lora_mapping
,
block_size
=
block_size
,
num_experts
=
ctx
.
num_experts
,
max_loras
=
ctx
.
num_loras
,
)
)
sorted_token_ids
=
sorted_token_ids_lora
.
view
(
ctx
.
num_loras
,
-
1
)
expert_ids
=
expert_ids_lora
.
view
(
ctx
.
num_loras
,
-
1
)
num_tokens_post_padded
=
num_tokens_post_padded_lora
return
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
def
as_lora_shrink_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
i_shape
,
lw_shape
,
o_shape
=
(
...
@@ -519,11 +742,13 @@ class BenchmarkTensors:
...
@@ -519,11 +742,13 @@ class BenchmarkTensors:
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
}
}
def
as_lora_expand_kwargs
(
self
,
add_inputs
:
bool
)
->
dict
[
str
,
Any
]:
def
as_lora_expand_kwargs
(
self
.
sanity_check
()
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
add_inputs
:
bool
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
i_shape
,
lw_shape
,
o_shape
=
(
...
@@ -560,22 +785,177 @@ class BenchmarkTensors:
...
@@ -560,22 +785,177 @@ class BenchmarkTensors:
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
"no_lora_flag_cpu"
:
self
.
lora_kernel_meta
.
no_lora_flag_cpu
,
}
}
def as_fused_moe_lora_shrink_kwargs(
    self, ctx: BenchmarkContext, op_type: OpType
) -> dict[str, Any]:
    """Build the keyword arguments for a fused MoE LoRA shrink call.

    Runs the sanity checks, validates the input / LoRA weight / output
    shapes, looks up the kernel configuration for ``op_type`` and prepares
    the MoE routing tensors before assembling the kwargs dict.

    Args:
        ctx: Benchmark context; ``top_k_num`` and ``num_experts`` are read
            from it here.
        op_type: The op being benchmarked; its lower-cased name selects the
            kernel configuration.

    Returns:
        The kwargs dict built at the end of this method.

    Raises:
        AssertionError: If a tensor shape does not match the expected layout.
    """
    self.sanity_check(ctx, op_type)
    self.to_device(self.input.device)

    # metadata() returns (num_seqs, num_tokens, max_seq_len, num_slices);
    # only num_tokens and num_slices are needed here.
    _, num_tokens, _, num_slices = self.metadata(ctx, op_type)

    # Sanity check matrix shapes.
    i_shape, lw_shape, o_shape = (
        self.input.shape,
        self.lora_weights_lst[0].shape,
        self.output.shape,
    )
    # Expected input shape : [num_tokens, hidden_size] for gate_up
    # Expected input shape : [top_k_num * num_tokens, hidden_size] for down
    assert len(i_shape) == 2
    assert i_shape[0] == num_tokens
    hidden_size = i_shape[1]
    # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
    assert len(lw_shape) == 4
    assert lw_shape[-1] == hidden_size
    lora_rank = lw_shape[-2]
    # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
    assert len(o_shape) == 4
    # For the down-shrink op num_tokens is compared after dividing it by
    # top_k_num; for every other op it is compared as-is.
    assert (
        o_shape
        == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank)
        if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
        else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank)
    )
    # Kernel launch parameters, keyed by the lower-cased op name.
    kernel_config = get_lora_op_configs(
        op_type.name.lower(),
        max_loras=lw_shape[0],
        batch=num_tokens,
        hidden_size=hidden_size,
        rank=lora_rank,
        num_slices=num_slices,
        add_inputs=False,
    )

    # Routing tensors are aligned to the kernel's BLOCK_SIZE_M.
    (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
        self.fused_moe_lora_data_prepare(
            block_size=kernel_config["BLOCK_SIZE_M"],
            token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
            ctx=ctx,
        )
    )

    return {
        "qcurr_hidden_states": self.input,
        "lora_a_stacked": self.lora_weights_lst,
        "a_intermediate_cache1": self.output,
        "topk_weights": topk_weights,
        "sorted_token_ids": sorted_token_ids,
        "expert_ids": expert_ids,
        "num_tokens_post_padded": num_tokens_post_padded,
        "top_k_num": ctx.top_k_num,
        "device": self.input.device,
        "N": lora_rank,
        "M": topk_weights.shape[0],
        "EM": sorted_token_ids.shape[1],
        "K": self.input.shape[1],
        "num_tokens": num_tokens,
        "num_experts": ctx.num_experts,
        "num_slices": num_slices,
        "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"],
        "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"],
        "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"],
        "shrink_group_size_m": kernel_config["GROUP_SIZE_M"],
        "shrink_num_warps": kernel_config["NUM_WARPS"],
        "shrink_num_stages": kernel_config["NUM_STAGES"],
        # SPLIT_K is optional in the kernel config; 1 is used when absent.
        "shrink_split_k": kernel_config.get("SPLIT_K", 1),
        # Set only for the down-projection op types.
        "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
    }
def
as_fused_moe_lora_expand_kwargs
(
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
)
->
dict
[
str
,
Any
]:
self
.
sanity_check
(
ctx
,
op_type
)
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
(
ctx
,
op_type
)
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
(
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
,
)
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
assert
len
(
i_shape
)
==
4
assert
i_shape
[
0
]
==
num_slices
assert
i_shape
[
1
]
==
num_tokens
lora_rank
=
i_shape
[
-
1
]
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
assert
len
(
lw_shape
)
==
4
assert
lw_shape
[
-
1
]
==
lora_rank
hidden_size
=
lw_shape
[
-
2
]
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
assert
len
(
o_shape
)
==
3
assert
o_shape
==
(
num_tokens
,
ctx
.
top_k_num
,
hidden_size
*
num_slices
)
kernel_config
=
get_lora_op_configs
(
op_type
.
name
.
lower
(),
max_loras
=
lw_shape
[
0
],
batch
=
num_tokens
,
hidden_size
=
hidden_size
,
rank
=
lora_rank
,
num_slices
=
num_slices
,
add_inputs
=
False
,
)
(
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
)
=
(
self
.
fused_moe_lora_data_prepare
(
block_size
=
kernel_config
[
"BLOCK_SIZE_M"
],
token_lora_mapping
=
self
.
lora_kernel_meta
.
token_lora_mapping
,
ctx
=
ctx
,
)
)
return
{
"a_intermediate_cache1"
:
self
.
input
,
"lora_b_stacked"
:
self
.
lora_weights_lst
,
"output"
:
self
.
output
,
"topk_weights"
:
topk_weights
,
"sorted_token_ids"
:
sorted_token_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_post_padded"
:
num_tokens_post_padded
,
"top_k_num"
:
ctx
.
top_k_num
,
"device"
:
self
.
input
.
device
,
"N"
:
lora_rank
,
"M"
:
topk_weights
.
shape
[
0
],
"EM"
:
sorted_token_ids
.
shape
[
1
],
"K"
:
self
.
input
.
shape
[
1
],
"num_tokens"
:
num_tokens
,
"num_experts"
:
ctx
.
num_experts
,
"num_slices"
:
num_slices
,
"max_lora_rank"
:
lora_rank
,
"w1_output_dim_size"
:
lw_shape
[
2
],
"expand_block_size_m"
:
kernel_config
[
"BLOCK_SIZE_M"
],
"expand_block_size_n"
:
kernel_config
[
"BLOCK_SIZE_N"
],
"expand_block_size_k"
:
kernel_config
[
"BLOCK_SIZE_K"
],
"expand_group_size_m"
:
kernel_config
[
"GROUP_SIZE_M"
],
"expand_num_warps"
:
kernel_config
[
"NUM_WARPS"
],
"expand_num_stages"
:
kernel_config
[
"NUM_STAGES"
],
"expand_split_k"
:
kernel_config
.
get
(
"SPLIT_K"
,
1
),
"mul_routed_weight"
:
op_type
.
is_fused_moe_lora_down_fn
(),
}
def
bench_fn_kwargs
(
def
bench_fn_kwargs
(
self
,
op_type
:
OpType
,
add_inputs
:
Optional
[
bool
]
=
None
self
,
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
add_inputs
:
bool
|
None
=
None
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
if
op_type
.
is_shrink_fn
():
if
op_type
.
is_shrink_fn
()
or
op_type
.
is_fused_moe_lora_fn
()
:
assert
add_inputs
is
None
assert
add_inputs
is
None
else
:
else
:
assert
add_inputs
is
not
None
assert
add_inputs
is
not
None
if
op_type
==
OpType
.
LORA_SHRINK
:
if
op_type
==
OpType
.
LORA_SHRINK
:
return
self
.
as_lora_shrink_kwargs
()
return
self
.
as_lora_shrink_kwargs
(
ctx
,
op_type
)
if
op_type
==
OpType
.
LORA_EXPAND
:
if
op_type
==
OpType
.
LORA_EXPAND
:
return
self
.
as_lora_expand_kwargs
(
add_inputs
)
return
self
.
as_lora_expand_kwargs
(
ctx
,
op_type
,
add_inputs
)
if
op_type
.
is_fused_moe_lora_shrink_fn
():
return
self
.
as_fused_moe_lora_shrink_kwargs
(
ctx
,
op_type
)
if
op_type
.
is_fused_moe_lora_expand_fn
():
return
self
.
as_fused_moe_lora_expand_kwargs
(
ctx
,
op_type
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
def
test_correctness
(
def
test_correctness
(
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
Optional
[
bool
]
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
bool
|
None
)
->
bool
:
)
->
bool
:
"""
"""
Test correctness of op_type implementation against a grouped gemm
Test correctness of op_type implementation against a grouped gemm
...
@@ -611,12 +991,12 @@ def bench_optype(
...
@@ -611,12 +991,12 @@ def bench_optype(
ctx
:
BenchmarkContext
,
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
arg_pool_size
:
int
,
op_type
:
OpType
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
cuda_graph_nops
:
int
|
None
=
None
,
expand_fn_add_inputs
:
Optional
[
bool
]
=
None
,
expand_fn_add_inputs
:
bool
|
None
=
None
,
test_correctness
:
bool
=
False
,
test_correctness
:
bool
=
False
,
)
->
TMeasurement
:
)
->
TMeasurement
:
assert
arg_pool_size
>=
1
assert
arg_pool_size
>=
1
if
op_type
.
is_shrink_fn
():
if
op_type
.
is_shrink_fn
()
or
op_type
.
is_fused_moe_lora_fn
()
:
assert
expand_fn_add_inputs
is
None
assert
expand_fn_add_inputs
is
None
else
:
else
:
assert
expand_fn_add_inputs
is
not
None
assert
expand_fn_add_inputs
is
not
None
...
@@ -626,23 +1006,30 @@ def bench_optype(
...
@@ -626,23 +1006,30 @@ def bench_optype(
BenchmarkTensors
.
make
(
ctx
,
op_type
)
for
_
in
range
(
arg_pool_size
)
BenchmarkTensors
.
make
(
ctx
,
op_type
)
for
_
in
range
(
arg_pool_size
)
]
]
for
bt
in
bench_tensors
:
for
bt
in
bench_tensors
:
bt
.
sanity_check
()
bt
.
sanity_check
(
ctx
,
op_type
)
# Test correctness of our implementation.
# Test correctness of our implementation.
if
test_correctness
:
if
test_correctness
:
assert
op_type
in
[
OpType
.
LORA_SHRINK
,
OpType
.
LORA_EXPAND
],
(
f
"Correctness testing is not supported for
{
op_type
.
name
}
."
)
assert
all
(
assert
all
(
[
bt
.
test_correctness
(
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
[
bt
.
test_correctness
(
ctx
,
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
)
)
# BenchmarkTensors -> dict (kwargs)
# BenchmarkTensors -> dict (kwargs)
kwargs_list
=
[
kwargs_list
=
[
bt
.
bench_fn_kwargs
(
op_type
,
add_inputs
=
expand_fn_add_inputs
)
bt
.
bench_fn_kwargs
(
ctx
,
op_type
,
add_inputs
=
expand_fn_add_inputs
)
for
bt
in
bench_tensors
for
bt
in
bench_tensors
]
]
# Clear LoRA optimization hash-maps.
# Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT
.
clear
()
_LORA_A_PTR_DICT
.
clear
()
_LORA_B_PTR_DICT
.
clear
()
_LORA_B_PTR_DICT
.
clear
()
_LORA_PTR_DICT
.
clear
()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
op_type
.
bench_fn
()(
**
kwargs
)
op_type
.
bench_fn
()(
**
kwargs
)
...
@@ -679,7 +1066,7 @@ def bench_torch_mm(
...
@@ -679,7 +1066,7 @@ def bench_torch_mm(
ctx
:
BenchmarkContext
,
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
arg_pool_size
:
int
,
op_type
:
OpType
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
cuda_graph_nops
:
int
|
None
=
None
,
)
->
TMeasurement
:
)
->
TMeasurement
:
"""
"""
Benchmark basic torch.mm as a roofline.
Benchmark basic torch.mm as a roofline.
...
@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
...
@@ -744,7 +1131,7 @@ def use_cuda_graph_recommendation() -> str:
"""
"""
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
Optional
[
argparse
.
Namespace
]
=
None
):
def
print_timers
(
timers
:
list
[
TMeasurement
],
args
:
argparse
.
Namespace
|
None
=
None
):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
compare
.
print
()
...
@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
...
@@ -792,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
# Benchmark bench_op
# Benchmark bench_op
expand_fn_add_inputs
=
(
expand_fn_add_inputs
=
(
[
None
]
if
bench_op
.
is_shrink_fn
()
else
args
.
expand_fn_add_inputs
[
None
]
if
bench_op
.
is_shrink_fn
()
or
bench_op
.
is_fused_moe_lora_fn
()
else
args
.
expand_fn_add_inputs
)
)
for
add_input_arg
in
expand_fn_add_inputs
:
for
add_input_arg
in
expand_fn_add_inputs
:
seq_len_timers
.
append
(
seq_len_timers
.
append
(
...
@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
...
@@ -830,12 +1219,22 @@ def as_benchmark_contexts(
hidden_sizes
:
list
[
int
],
lora_ranks
:
list
[
int
],
args
:
argparse
.
Namespace
hidden_sizes
:
list
[
int
],
lora_ranks
:
list
[
int
],
args
:
argparse
.
Namespace
)
->
list
[
BenchmarkContext
]:
)
->
list
[
BenchmarkContext
]:
ctxs
:
list
[
BenchmarkContext
]
=
[]
ctxs
:
list
[
BenchmarkContext
]
=
[]
for
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
in
product
(
# noqa
for
(
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
,
top_k_num
,
num_experts
,
)
in
product
(
# noqa
args
.
batch_sizes
,
args
.
batch_sizes
,
list
(
hidden_sizes
),
list
(
hidden_sizes
),
lora_ranks
,
lora_ranks
,
args
.
num_loras
,
args
.
num_loras
,
args
.
sort_by_lora_id
,
args
.
sort_by_lora_id
,
args
.
top_k_nums
,
args
.
num_experts
,
):
):
ctxs
.
append
(
ctxs
.
append
(
BenchmarkContext
(
BenchmarkContext
(
...
@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
...
@@ -850,6 +1249,8 @@ def as_benchmark_contexts(
seq_length
=
None
,
seq_length
=
None
,
sort_by_lora_id
=
sort_by_lora_id
,
sort_by_lora_id
=
sort_by_lora_id
,
dtype
=
args
.
dtype
,
dtype
=
args
.
dtype
,
top_k_num
=
top_k_num
,
num_experts
=
num_experts
,
# To be filled based on the OpType to benchmark
# To be filled based on the OpType to benchmark
num_slices
=
None
,
num_slices
=
None
,
)
)
...
@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
...
@@ -1011,6 +1412,22 @@ if __name__ == "__main__":
),
),
)
)
p
.
add_argument
(
"--top-k-nums"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TOP_K_NUMS
,
help
=
"Top-K values for MoE LoRA operations"
,
)
p
.
add_argument
(
"--num-experts"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_NUM_EXPERTS
,
help
=
"Number of experts for MoE LoRA operations"
,
)
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
f
"""
description
=
f
"""
Benchmark LoRA kernels:
Benchmark LoRA kernels:
...
...
benchmarks/kernels/benchmark_machete.py
View file @
41199996
...
@@ -8,10 +8,9 @@ import math
...
@@ -8,10 +8,9 @@ import math
import
os
import
os
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
product
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
pandas
as
pd
import
pandas
as
pd
import
torch
import
torch
...
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights
,
quantize_weights
,
)
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-3-8b"
,
"meta-llama/Llama-2-70b-hf"
]
DEFAULT_MODELS
=
[
"meta-llama/Llama-3-8b"
,
"meta-llama/Llama-2-70b-hf"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
...
@@ -63,23 +62,23 @@ class BenchmarkTensors:
...
@@ -63,23 +62,23 @@ class BenchmarkTensors:
a
:
torch
.
Tensor
a
:
torch
.
Tensor
w_q
:
torch
.
Tensor
w_q
:
torch
.
Tensor
group_size
:
Optional
[
int
]
group_size
:
int
|
None
wtype
:
ScalarType
wtype
:
ScalarType
w_g_s
:
torch
.
Tensor
w_g_s
:
torch
.
Tensor
w_g_zp
:
Optional
[
torch
.
Tensor
]
w_g_zp
:
torch
.
Tensor
|
None
w_ch_s
:
Optional
[
torch
.
Tensor
]
w_ch_s
:
torch
.
Tensor
|
None
w_tok_s
:
Optional
[
torch
.
Tensor
]
w_tok_s
:
torch
.
Tensor
|
None
@
dataclass
@
dataclass
class
TypeConfig
:
class
TypeConfig
:
act_type
:
torch
.
dtype
act_type
:
torch
.
dtype
weight_type
:
ScalarType
weight_type
:
ScalarType
output_type
:
Optional
[
torch
.
dtype
]
output_type
:
torch
.
dtype
|
None
group_scale_type
:
Optional
[
torch
.
dtype
]
group_scale_type
:
torch
.
dtype
|
None
group_zero_type
:
Optional
[
torch
.
dtype
]
group_zero_type
:
torch
.
dtype
|
None
channel_scale_type
:
Optional
[
torch
.
dtype
]
channel_scale_type
:
torch
.
dtype
|
None
token_scale_type
:
Optional
[
torch
.
dtype
]
token_scale_type
:
torch
.
dtype
|
None
def
rand_data
(
shape
,
dtype
=
torch
.
float16
,
scale
=
1
):
def
rand_data
(
shape
,
dtype
=
torch
.
float16
,
scale
=
1
):
...
@@ -93,8 +92,8 @@ def quantize_and_pack(
...
@@ -93,8 +92,8 @@ def quantize_and_pack(
atype
:
torch
.
dtype
,
atype
:
torch
.
dtype
,
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
wtype
:
ScalarType
,
wtype
:
ScalarType
,
stype
:
Optional
[
torch
.
dtype
]
,
stype
:
torch
.
dtype
|
None
,
group_size
:
Optional
[
int
]
,
group_size
:
int
|
None
,
zero_points
:
bool
=
False
,
zero_points
:
bool
=
False
,
):
):
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
...
@@ -113,7 +112,7 @@ def quantize_and_pack(
...
@@ -113,7 +112,7 @@ def quantize_and_pack(
def
create_bench_tensors
(
def
create_bench_tensors
(
shape
:
tuple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
Optional
[
int
]
shape
:
tuple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
int
|
None
)
->
list
[
BenchmarkTensors
]:
)
->
list
[
BenchmarkTensors
]:
m
,
n
,
k
=
shape
m
,
n
,
k
=
shape
...
@@ -238,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
...
@@ -238,6 +237,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
b_q_weight
=
w_q
,
b_q_weight
=
w_q
,
b_bias
=
None
,
b_bias
=
None
,
b_scales
=
w_s
,
b_scales
=
w_s
,
a_scales
=
None
,
global_scale
=
None
,
global_scale
=
None
,
b_zeros
=
w_zp
,
b_zeros
=
w_zp
,
g_idx
=
g_idx
,
g_idx
=
g_idx
,
...
@@ -331,8 +331,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
...
@@ -331,8 +331,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return
res
return
res
_SWEEP_SCHEDULES_RESULTS
:
Optional
[
pd
.
DataFrame
]
=
None
_SWEEP_SCHEDULES_RESULTS
:
pd
.
DataFrame
|
None
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
Optional
[
str
]
=
None
_SWEEP_SCHEDULES_RESULTS_CSV
:
str
|
None
=
None
def
bench
(
def
bench
(
...
...
benchmarks/kernels/benchmark_marlin.py
View file @
41199996
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
sort_weights
,
sort_weights
,
)
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_MODELS
=
[
"meta-llama/Llama-2-7b-hf/TP1"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
...
@@ -263,7 +263,7 @@ def bench_run(
...
@@ -263,7 +263,7 @@ def bench_run(
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s,
None,
marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
...
@@ -273,7 +273,7 @@ def bench_run(
...
@@ -273,7 +273,7 @@ def bench_run(
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
stmt
=
"output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s,
None,
marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
...
...
benchmarks/kernels/benchmark_moe.py
View file @
41199996
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# 移除全局的 current_platform 导入,改为在需要时局部导入
# 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
# FP8_DTYPE = current_platform.fp8_dtype()
...
@@ -228,8 +228,8 @@ def benchmark_config(
...
@@ -228,8 +228,8 @@ def benchmark_config(
# run()
# run()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
for
i
in
range
(
num_iters
):
...
@@ -253,10 +253,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
...
@@ -253,10 +253,10 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_k_range
=
[
32
,
64
,
128
,
256
]
block_k_range
=
[
32
,
64
,
128
,
256
]
if
not
use_fp16
:
if
not
use_fp16
:
block_k_range
.
remove
(
16
)
# BLOCK_K=16 not supported for fp8
block_k_range
.
remove
(
16
)
# BLOCK_K=16 not supported for fp8
num_warps_range
=
[
2
,
4
,
8
]
num_warps_range
=
[
1
,
2
,
4
,
8
]
group_m_range
=
[
1
,
16
,
32
,
64
]
group_m_range
=
[
1
,
4
,
8
,
16
,
32
]
num_stage_range
=
[
2
,
3
,
4
,
5
]
num_stage_range
=
[
2
]
# waves_per_eu_range = [0]
# waves_per_eu_range = [0
, 1, 2, 4
]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else []
...
@@ -669,19 +669,23 @@ def main(args: argparse.Namespace):
...
@@ -669,19 +669,23 @@ def main(args: argparse.Namespace):
E
=
config
.
ffn_config
.
moe_num_experts
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
elif
config
.
architectures
[
0
]
in
(
"DeepseekV2ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"DeepseekV3ForCausalLM"
,
"DeepseekV3ForCausalLM"
,
"DeepseekV32ForCausalLM"
,
"DeepseekV32ForCausalLM"
,
"Glm4MoeForCausalLM"
,
"Glm4MoeForCausalLM"
,
"NemotronHForCausalLM"
,
):
):
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
elif
config
.
architectures
[
0
]
in
(
"Qwen2MoeForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
,
...
@@ -690,14 +694,27 @@ def main(args: argparse.Namespace):
...
@@ -690,14 +694,27 @@ def main(args: argparse.Namespace):
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
==
"Qwen3VLMoeForConditionalGeneration"
:
text_config
=
config
.
get_text_config
()
E
=
text_config
.
num_experts
topk
=
text_config
.
num_experts_per_tok
intermediate_size
=
text_config
.
moe_intermediate_size
hidden_size
=
text_config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"HunYuanMoEV1ForCausalLM"
):
elif
config
.
architectures
[
0
]
in
(
"HunYuanMoEV1ForCausalLM"
):
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
moe_topk
[
0
]
topk
=
config
.
moe_topk
[
0
]
intermediate_size
=
config
.
moe_intermediate_size
[
0
]
intermediate_size
=
config
.
moe_intermediate_size
[
0
]
hidden_size
=
config
.
hidden_size
elif
config
.
architectures
[
0
]
in
(
"Step3VLForConditionalGeneration"
):
elif
config
.
architectures
[
0
]
in
(
"Step3VLForConditionalGeneration"
):
E
=
config
.
text_config
.
moe_num_experts
E
=
config
.
text_config
.
moe_num_experts
topk
=
config
.
text_config
.
moe_top_k
topk
=
config
.
text_config
.
moe_top_k
intermediate_size
=
config
.
text_config
.
moe_intermediate_size
intermediate_size
=
config
.
text_config
.
moe_intermediate_size
elif
config
.
architectures
[
0
]
in
[
"Qwen3OmniMoeForConditionalGeneration"
]:
E
=
config
.
thinker_config
.
text_config
.
num_experts
topk
=
config
.
thinker_config
.
text_config
.
num_experts_per_tok
intermediate_size
=
config
.
thinker_config
.
text_config
.
moe_intermediate_size
hidden_size
=
config
.
thinker_config
.
text_config
.
hidden_size
else
:
else
:
# Support for llama4
# Support for llama4
config
=
config
.
get_text_config
()
config
=
config
.
get_text_config
()
...
@@ -705,16 +722,16 @@ def main(args: argparse.Namespace):
...
@@ -705,16 +722,16 @@ def main(args: argparse.Namespace):
E
=
config
.
num_local_experts
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
hidden_size
=
config
.
hidden_size
enable_ep
=
bool
(
args
.
enable_expert_parallel
)
enable_ep
=
bool
(
args
.
enable_expert_parallel
)
if
enable_ep
:
if
enable_ep
:
ensure_divisibility
(
E
,
tp_size
,
"Number of experts"
)
ensure_divisibility
(
E
,
tp_size
,
"Number of experts"
)
E
=
E
//
tp_size
E
=
E
//
tp_size
shard_intermediate_size
=
2
*
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
else
:
else
:
ensure_divisibility
(
intermediate_size
,
tp_size
,
"intermediate_size"
)
ensure_divisibility
(
intermediate_size
,
args
.
tp_size
,
"intermediate_size"
)
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
block_quant_shape
=
get_weight_block_size_safety
(
config
)
...
...
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
41199996
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
)
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -105,8 +105,8 @@ def benchmark_permute(
...
@@ -105,8 +105,8 @@ def benchmark_permute(
graph
.
replay
()
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
for
i
in
range
(
num_iters
):
...
@@ -241,8 +241,8 @@ def benchmark_unpermute(
...
@@ -241,8 +241,8 @@ def benchmark_unpermute(
graph
.
replay
()
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
=
torch
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
for
i
in
range
(
num_iters
):
...
@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
...
@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_
dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_customized_permute
=
args
.
use_customized_permute
use_customized_permute
=
args
.
use_customized_permute
...
...
benchmarks/kernels/benchmark_mrope.py
View file @
41199996
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#
#
# The CSV file (named with current date/time) contains these columns:
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
#
rope_theta,
is_neox_style, rope_
scaling
, dtype, torch_mean, torch_median, torch_p99,
# is_neox_style, rope_
parameters
, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
# speedup
#
#
...
@@ -39,7 +39,7 @@ import torch
...
@@ -39,7 +39,7 @@ import torch
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.config
import
get_config
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
...
@@ -86,9 +86,8 @@ def benchmark_mrope(
...
@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
rope_theta
:
float
=
10000
,
is_neox_style
:
bool
=
True
,
is_neox_style
:
bool
=
True
,
rope_
scaling
:
dict
[
str
,
Any
]
=
None
,
rope_
parameters
:
dict
[
str
,
Any
]
|
None
=
None
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
seed
:
int
=
0
,
seed
:
int
=
0
,
warmup_iter
:
int
=
10
,
warmup_iter
:
int
=
10
,
...
@@ -102,9 +101,8 @@ def benchmark_mrope(
...
@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size
=
head_dim
,
head_size
=
head_dim
,
rotary_dim
=
head_dim
,
rotary_dim
=
head_dim
,
max_position
=
max_position
,
max_position
=
max_position
,
base
=
rope_theta
,
is_neox_style
=
is_neox_style
,
is_neox_style
=
is_neox_style
,
rope_
scaling
=
rope_scaling
,
rope_
parameters
=
rope_parameters
,
dtype
=
dtype
,
dtype
=
dtype
,
).
to
(
device
=
device
)
).
to
(
device
=
device
)
...
@@ -203,9 +201,8 @@ def benchmark_mrope(
...
@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads
,
num_kv_heads
,
head_dim
,
head_dim
,
max_position
,
max_position
,
rope_theta
,
is_neox_style
,
is_neox_style
,
str
(
rope_
scaling
),
str
(
rope_
parameters
),
str
(
dtype
).
split
(
"."
)[
-
1
],
str
(
dtype
).
split
(
"."
)[
-
1
],
torch_stats
[
"mean"
],
torch_stats
[
"mean"
],
torch_stats
[
"median"
],
torch_stats
[
"median"
],
...
@@ -255,9 +252,8 @@ if __name__ == "__main__":
...
@@ -255,9 +252,8 @@ if __name__ == "__main__":
"num_kv_heads"
,
"num_kv_heads"
,
"head_dim"
,
"head_dim"
,
"max_position"
,
"max_position"
,
"rope_theta"
,
"is_neox_style"
,
"is_neox_style"
,
"rope_
scaling
"
,
"rope_
parameters
"
,
"dtype"
,
"dtype"
,
"torch_mean"
,
"torch_mean"
,
"torch_median"
,
"torch_median"
,
...
@@ -303,7 +299,7 @@ if __name__ == "__main__":
...
@@ -303,7 +299,7 @@ if __name__ == "__main__":
q_size
=
num_heads
*
head_dim
q_size
=
num_heads
*
head_dim
kv_size
=
num_kv_heads
*
head_dim
kv_size
=
num_kv_heads
*
head_dim
is_neox_style
=
True
is_neox_style
=
True
rope_
theta
=
config
.
rope_
theta
rope_
parameters
=
config
.
rope_
parameters
max_position
=
config
.
max_position_embeddings
max_position
=
config
.
max_position_embeddings
for
num_tokens
in
num_tokens_list
:
for
num_tokens
in
num_tokens_list
:
...
@@ -315,9 +311,8 @@ if __name__ == "__main__":
...
@@ -315,9 +311,8 @@ if __name__ == "__main__":
num_heads
=
num_heads
,
num_heads
=
num_heads
,
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
max_position
=
max_position
,
max_position
=
max_position
,
rope_theta
=
rope_theta
,
is_neox_style
=
is_neox_style
,
is_neox_style
=
is_neox_style
,
rope_
scaling
=
config
.
rope_scaling
,
rope_
parameters
=
rope_parameters
,
dtype
=
getattr
(
torch
,
args
.
dtype
),
dtype
=
getattr
(
torch
,
args
.
dtype
),
seed
=
args
.
seed
,
seed
=
args
.
seed
,
warmup_iter
=
args
.
warmup_iter
,
warmup_iter
=
args
.
warmup_iter
,
...
...
benchmarks/kernels/benchmark_paged_attention.py
View file @
41199996
...
@@ -3,20 +3,18 @@
...
@@ -3,20 +3,18 @@
import
random
import
random
import
time
import
time
from
typing
import
Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
import
vllm.envs
as
envs
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
,
create_kv_caches_with_random
,
)
)
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -39,7 +37,7 @@ def main(
...
@@ -39,7 +37,7 @@ def main(
seed
:
int
,
seed
:
int
,
do_profile
:
bool
,
do_profile
:
bool
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
kv_cache_dtype
:
Optional
[
str
]
=
None
,
kv_cache_dtype
:
str
|
None
=
None
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
...
benchmarks/kernels/benchmark_per_token_group_quant.py
View file @
41199996
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
import
argparse
import
argparse
import
math
import
math
from
collections.abc
import
Callable
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Callable
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -30,8 +30,8 @@ def _time_cuda(
...
@@ -30,8 +30,8 @@ def _time_cuda(
fn
()
fn
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
=
torch
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
Event
(
enable_timing
=
True
)
start
.
record
()
start
.
record
()
for
_
in
range
(
bench_iters
):
for
_
in
range
(
bench_iters
):
...
...
benchmarks/kernels/benchmark_polynorm.py
deleted
100644 → 0
View file @
31021d81
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
torch
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Reference PolyNorm implemented with plain PyTorch ops.

    Computes
    ``weight[0] * norm(x**3) + weight[1] * norm(x**2) + weight[2] * norm(x) + bias``
    where ``norm(t) = t / sqrt(mean(t**2, dim=-1, keepdim=True) + eps)``.

    Args:
        x: Input tensor; normalization runs over its last dimension.
        weight: Indexable tensor providing the three polynomial coefficients
            ``weight[0]``, ``weight[1]`` and ``weight[2]``.
        bias: Tensor added to the weighted sum (broadcast by PyTorch).
        eps: Value added under the square root for numerical stability.

    Returns:
        A tensor with the same shape as ``x`` and the dtype of ``weight``.
    """
    orig_shape = x.shape
    # reshape() instead of view(): view() raises for non-contiguous inputs
    # (e.g. a transposed activation), reshape() copies only when required.
    # The math is carried out in float32 regardless of the input dtype.
    x = x.reshape(-1, x.shape[-1]).float()

    def norm(t: torch.Tensor, eps: float) -> torch.Tensor:
        # RMS-style normalization over the last dimension.
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    combined = (
        weight[0] * norm(x**3, eps)
        + weight[1] * norm(x**2, eps)
        + weight[2] * norm(x, eps)
        + bias
    )
    return combined.to(weight.dtype).reshape(orig_shape)
def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Run PolyNorm through the vLLM custom op ``vllm_ops.poly_norm``.

    The input is flattened to 2-D (rows x last dimension), an output buffer
    of the same shape/dtype is allocated, and the op is invoked with that
    buffer as its first argument. The buffer is then returned viewed in the
    original shape of ``x``.

    NOTE(review): the op is presumably expected to fill the buffer in place;
    its return value is not used here.
    """
    input_shape = x.shape
    flat = x.view(-1, input_shape[-1])
    result = torch.empty_like(flat)
    vllm_ops.poly_norm(result, flat, weight, bias, eps)
    return result.view(input_shape)
def calculate_diff(batch_size, seq_len, hidden_dim):
    """Compare the naive and vLLM PolyNorm outputs on one random input.

    Builds a bfloat16 CUDA input of shape (batch_size, seq_len, hidden_dim)
    with all-ones weight (3 elements) and bias (1 element), runs both
    implementations, and prints whether they agree within
    atol=1e-2 / rtol=1e-2.
    """
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    x = torch.randn(batch_size, seq_len, hidden_dim, **tensor_opts)
    weight = torch.ones(3, **tensor_opts)
    bias = torch.ones(1, **tensor_opts)

    expected = polynorm_naive(x, weight, bias)
    actual = polynorm_vllm(x, weight, bias)

    matches = torch.allclose(expected, actual, atol=1e-2, rtol=1e-2)
    if not matches:
        print("❌ Implementations differ")
        return
    print("✅ All implementations match")
# Sweep axes for the performance benchmark.
# Batch sizes: powers of two with even exponents 0, 2, 4, 6 -> 1, 4, 16, 64.
batch_size_range = [1 << exponent for exponent in (0, 2, 4, 6)]
# Sequence lengths: 2**6 .. 2**10 -> 64, 128, 256, 512, 1024.
seq_length_range = [1 << exponent for exponent in range(6, 11)]
dim_range = [2048, 4096]
# Cartesian product in (dim, batch_size, seq_len) order, dim varying slowest.
configs = [
    (dim, batch, seq)
    for dim in dim_range
    for batch in batch_size_range
    for seq in seq_length_range
]
def get_benchmark():
    """Build the Triton ``perf_report`` benchmark comparing both PolyNorms.

    Returns:
        The decorated ``benchmark`` callable; call ``.run(...)`` on it to
        sweep every ``(dim, batch_size, seq_len)`` entry of ``configs`` for
        the ``"naive"`` and ``"vllm"`` providers.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        """Time one provider on one configuration; results in microseconds."""
        dtype = torch.bfloat16
        # The benchmarked tensor's last dimension is four times ``dim``.
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        # do_bench returns one value per requested quantile, in this order:
        # median, 20th percentile (fast end), 80th percentile (slow end).
        quantiles = [0.5, 0.2, 0.8]

        impl = polynorm_naive if provider == "naive" else polynorm_vllm
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: impl(x, weight, bias),
            quantiles=quantiles,
        )

        # perf_report unpacks the result as (y_mean, y_min, y_max). For a
        # latency metric the low quantile is the minimum, so min must come
        # before max (the swapped order is only right for throughput).
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    # Script entry point: run a one-off correctness check, then the full
    # performance sweep defined by get_benchmark().
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=8192,
        help="Intermediate size of MLP",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        # Default directory name kept as-is (including its spelling) so
        # existing result locations do not move.
        default="./configs/polnorm/",
        help="Path to save polynorm benchmark results",
    )

    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    benchmark = get_benchmark()
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
benchmarks/kernels/benchmark_quant.py
View file @
41199996
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
benchmarks/kernels/benchmark_reshape_and_cache.py
0 → 100644
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
time
import
torch
from
tabulate
import
tabulate
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
create_kv_caches_with_random
,
)
logger
=
init_logger
(
__name__
)
@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Measure the average latency, in seconds, of ``ops.reshape_and_cache``.

    Random key/value tensors of shape (num_tokens, num_heads, head_size) are
    written into a single-layer KV cache, each token going to a distinct,
    randomly chosen slot. When ``benchmark_mode`` is ``"cudagraph"`` the call
    is captured once into a CUDA graph and the replay is timed; otherwise the
    op is launched directly.

    Raises:
        ValueError: if ``kv_cache_dtype`` is ``"fp8"`` and ``head_size`` is
            not a multiple of 16, or if ``num_tokens`` exceeds
            ``block_size * num_blocks``.
    """
    if kv_cache_dtype == "fp8" and head_size % 16 != 0:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    # Fixed seed so every invocation draws the same inputs.
    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # Inputs laid out as [tokens, heads, head_dim].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # Sampling without replacement gives every token its own cache slot.
    total_slots = block_size * num_blocks
    if num_tokens > total_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    chosen_slots = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(chosen_slots, dtype=torch.long, device=device)

    layer_key_caches, layer_value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache = layer_key_caches[0]
    value_cache = layer_value_caches[0]
    # Drop the list references; only the single layer picked above is used.
    del layer_key_caches, layer_value_caches

    # Scaling factors handed to the kernel (relevant for fp8 conversion).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def launch_kernel():
        return ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    function_under_test = launch_kernel
    if benchmark_mode == "cudagraph":
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            launch_kernel()
        torch.cuda.synchronize()

        def function_under_test():
            return graph.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        # Synchronize on both sides so the wall-clock window covers exactly
        # the queued launches.
        torch.cuda.synchronize()
        t_begin = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        t_end = time.perf_counter()
        return (t_end - t_begin) / n_iters

    # Warm-up pass; its timing is discarded.
    run_cuda_benchmark(3)
    lat = run_cuda_benchmark(num_iters)

    # Release tensors before the next sweep point to mitigate OOM.
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat
def main(args):
    """Sweep num_tokens over 2**1 .. 2**16 and print a latency table.

    Each sweep point calls ``run_benchmark`` with the settings taken from
    ``args``; the returned latency (seconds) is converted to microseconds
    before being tabulated.
    """
    table = []
    for n_tok in (2**exp for exp in range(1, 17)):
        latency_s = run_benchmark(
            num_tokens=n_tok,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        # seconds -> microseconds
        table.append([n_tok, latency_s * 1e6])

    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
    print(tabulate(table, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
if __name__ == "__main__":
    # Command-line entry point: declare the benchmark options, parse them,
    # and hand the resulting namespace to main().
    parser = FlexibleArgumentParser()

    # Shape of the key/value inputs and of the KV cache.
    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 128)

    # Data types of the inputs and of the cache.
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    # Measurement controls.
    parser.add_argument("--iters", type=int, default=200)
    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )

    main(parser.parse_args())
Prev
1
2
3
4
5
6
7
8
9
10
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment