Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cd3ed273
"csrc/vscode:/vscode.git/clone" did not exist on "98229db2444b016d572bfd36e685960b3352d900"
Commit
cd3ed273
authored
Jul 29, 2025
by
zhuwenwen
Browse files
update benchmark_moe.py
parent
be0549c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
24 deletions
+68
-24
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+68
-24
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
cd3ed273
...
@@ -7,19 +7,19 @@ import time
...
@@ -7,19 +7,19 @@ import time
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
datetime
import
datetime
from
itertools
import
product
from
itertools
import
product
from
typing
import
Any
,
TypedDict
from
typing
import
Any
,
TypedDict
,
Optional
import
ray
import
ray
import
torch
import
torch
from
ray.experimental.tqdm_ray
import
tqdm
from
ray.experimental.tqdm_ray
import
tqdm
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
# Removed the module-level current_platform import; it is now imported locally where needed
# FP8_DTYPE = current_platform.fp8_dtype()
class
BenchmarkConfig
(
TypedDict
):
class
BenchmarkConfig
(
TypedDict
):
...
@@ -47,8 +47,12 @@ def benchmark_config(
...
@@ -47,8 +47,12 @@ def benchmark_config(
use_deep_gemm
:
bool
=
False
,
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
nn_moe
:
Optional
[
bool
]
=
False
)
->
float
:
)
->
float
:
from
vllm.platforms
import
current_platform
device
=
torch
.
cuda
.
current_device
()
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
if
use_int8_w8a16
:
if
use_int8_w8a16
:
if
not
nn_moe
:
if
not
nn_moe
:
w1
=
torch
.
randint
(
w1
=
torch
.
randint
(
...
@@ -60,6 +64,7 @@ def benchmark_config(
...
@@ -60,6 +64,7 @@ def benchmark_config(
hidden_size
,
hidden_size
,
),
),
dtype
=
torch
.
int8
,
dtype
=
torch
.
int8
,
device
=
device
,
)
)
w2
=
torch
.
randint
(
w2
=
torch
.
randint
(
-
127
,
-
127
,
...
@@ -70,6 +75,7 @@ def benchmark_config(
...
@@ -70,6 +75,7 @@ def benchmark_config(
shard_intermediate_size
//
2
,
shard_intermediate_size
//
2
,
),
),
dtype
=
torch
.
int8
,
dtype
=
torch
.
int8
,
device
=
device
,
)
)
else
:
else
:
w1
=
torch
.
randint
(
w1
=
torch
.
randint
(
...
@@ -81,6 +87,7 @@ def benchmark_config(
...
@@ -81,6 +87,7 @@ def benchmark_config(
shard_intermediate_size
,
shard_intermediate_size
,
),
),
dtype
=
torch
.
int8
,
dtype
=
torch
.
int8
,
device
=
device
,
)
)
w2
=
torch
.
randint
(
w2
=
torch
.
randint
(
-
127
,
-
127
,
...
@@ -91,23 +98,24 @@ def benchmark_config(
...
@@ -91,23 +98,24 @@ def benchmark_config(
hidden_size
,
hidden_size
,
),
),
dtype
=
torch
.
int8
,
dtype
=
torch
.
int8
,
device
=
device
,
)
)
else
:
else
:
if
not
nn_moe
:
if
not
nn_moe
:
w1
=
torch
.
randn
(
w1
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
,
device
=
device
)
)
w2
=
torch
.
randn
(
w2
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
,
device
=
device
)
)
else
:
else
:
w1
=
torch
.
randn
(
w1
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
,
dtype
=
init_dtype
num_experts
,
hidden_size
,
shard_intermediate_size
,
dtype
=
init_dtype
,
device
=
device
)
)
w2
=
torch
.
randn
(
w2
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
dtype
=
init_dtype
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
dtype
=
init_dtype
,
device
=
device
)
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
w1_scale
=
None
w1_scale
=
None
w2_scale
=
None
w2_scale
=
None
...
@@ -115,9 +123,9 @@ def benchmark_config(
...
@@ -115,9 +123,9 @@ def benchmark_config(
a2_scale
=
None
a2_scale
=
None
if
use_int8_w8a16
:
if
use_int8_w8a16
:
w1_scale
=
torch
.
randn
(
w1_scale
=
torch
.
randn
(
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
,
device
=
device
)
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
device
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
if
block_quant_shape
:
if
block_quant_shape
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
...
@@ -130,24 +138,26 @@ def benchmark_config(
...
@@ -130,24 +138,26 @@ def benchmark_config(
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
w1_scale
=
(
w1_scale
=
(
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
,
device
=
device
)
*
factor_for_scale
*
factor_for_scale
)
)
w2_scale
=
(
w2_scale
=
(
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
,
device
=
device
)
*
factor_for_scale
*
factor_for_scale
)
)
else
:
else
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
# Look up the FP8 dtype from the current platform
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
w1
=
w1
.
to
(
FP8_DTYPE
)
w1
=
w1
.
to
(
FP8_DTYPE
)
w2
=
w2
.
to
(
FP8_DTYPE
)
w2
=
w2
.
to
(
FP8_DTYPE
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
def
prepare
(
i
:
int
):
def
prepare
(
i
:
int
):
input_gating
.
copy_
(
gating_output
[
i
])
input_gating
.
copy_
(
gating_output
[
i
])
...
@@ -266,6 +276,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
...
@@ -266,6 +276,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
def
get_configs_compute_bound
(
use_fp16
,
block_quant_shape
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
list
[
dict
[
str
,
int
]]:
def
get_configs_compute_bound
(
use_fp16
,
block_quant_shape
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
list
[
dict
[
str
,
int
]]:
configs
:
list
[
BenchmarkConfig
]
=
[]
configs
:
list
[
BenchmarkConfig
]
=
[]
# Import current_platform locally
from
vllm.platforms
import
current_platform
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
param_ranges
=
get_rocm_tuning_space
(
use_fp16
,
nn_moe
)
param_ranges
=
get_rocm_tuning_space
(
use_fp16
,
nn_moe
)
...
@@ -426,12 +439,18 @@ def merge_unique_dicts(list1, list2):
...
@@ -426,12 +439,18 @@ def merge_unique_dicts(list1, list2):
@
ray
.
remote
(
num_gpus
=
1
)
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
class
BenchmarkWorker
:
def
__init__
(
self
,
seed
:
int
,
device_id
:
int
)
->
None
:
def
__init__
(
self
,
seed
:
int
,
device_id
:
int
)
->
None
:
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
from
vllm.platforms
import
current_platform
import
os
if
current_platform
.
is_rocm
():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else
:
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
self
.
seed
=
seed
self
.
seed
=
seed
# Get the device ID to allocate tensors and kernels
# Store the logical device ID for Ray
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self
.
device_id
=
device_id
self
.
device_id
=
device_id
def
benchmark
(
def
benchmark
(
...
@@ -448,7 +467,13 @@ class BenchmarkWorker:
...
@@ -448,7 +467,13 @@ class BenchmarkWorker:
use_deep_gemm
:
bool
=
False
,
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
)
->
tuple
[
dict
[
str
,
int
],
float
]:
# Import current_platform locally
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
self
.
seed
)
current_platform
.
seed_everything
(
self
.
seed
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_moe_configs
,
get_default_config
)
dtype_str
=
get_config_dtype_str
(
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
)
...
@@ -502,6 +527,9 @@ class BenchmarkWorker:
...
@@ -502,6 +527,9 @@ class BenchmarkWorker:
use_deep_gemm
:
bool
,
use_deep_gemm
:
bool
,
nn_moe
:
Optional
[
bool
]
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
dict
[
str
,
int
]:
)
->
dict
[
str
,
int
]:
from
vllm.platforms
import
current_platform
import
os
best_config
=
None
best_config
=
None
best_time
=
float
(
"inf"
)
best_time
=
float
(
"inf"
)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
@@ -515,10 +543,16 @@ class BenchmarkWorker:
...
@@ -515,10 +543,16 @@ class BenchmarkWorker:
topk
,
topk
,
)
)
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard
=
False
need_device_guard
=
False
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
visible_device
=
os
.
environ
.
get
(
"ROCR_VISIBLE_DEVICES"
,
None
)
# For ROCm with Ray, skip additional device context management
if
visible_device
!=
f
"
{
self
.
device_id
}
"
:
need_device_guard
=
False
else
:
# For other platforms, use device guard if needed
visible_devices
=
os
.
environ
.
get
(
"CUDA_VISIBLE_DEVICES"
,
None
)
if
visible_devices
is
not
None
and
len
(
visible_devices
.
split
(
','
))
>
1
:
need_device_guard
=
True
need_device_guard
=
True
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
...
@@ -587,6 +621,10 @@ def save_configs(
...
@@ -587,6 +621,10 @@ def save_configs(
block_quant_shape
:
list
[
int
],
block_quant_shape
:
list
[
int
],
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
)
->
None
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_config_file_name
)
dtype_str
=
get_config_dtype_str
(
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
)
...
@@ -611,6 +649,12 @@ def get_weight_block_size_safety(config, default_value=None):
...
@@ -611,6 +649,12 @@ def get_weight_block_size_safety(config, default_value=None):
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
import
os
import
logging
from
vllm.platforms
import
current_platform
logger
=
logging
.
getLogger
(
__name__
)
print
(
args
)
print
(
args
)
tp_size
=
args
.
tp_size
tp_size
=
args
.
tp_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment