Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4589b940
Unverified
Commit
4589b940
authored
Jun 10, 2025
by
Tianyu Guo
Committed by
GitHub
Jun 09, 2025
Browse files
[Bugfix] Fix benchmark_moe.py (#19016)
Signed-off-by:
Tianyu Guo
<
guoty9@mail2.sysu.edu.cn
>
parent
cc867be1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
10 deletions
+4
-10
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+4
-10
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
4589b940
...
@@ -7,7 +7,6 @@ import time
...
@@ -7,7 +7,6 @@ import time
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
datetime
import
datetime
from
itertools
import
product
from
itertools
import
product
from
types
import
SimpleNamespace
from
typing
import
Any
,
TypedDict
from
typing
import
Any
,
TypedDict
import
ray
import
ray
...
@@ -43,7 +42,7 @@ def benchmark_config(
...
@@ -43,7 +42,7 @@ def benchmark_config(
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
num_iters
:
int
=
100
,
block_quant_shape
:
L
ist
[
int
]
=
None
,
block_quant_shape
:
l
ist
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
,
use_deep_gemm
:
bool
=
False
,
)
->
float
:
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
...
@@ -400,7 +399,7 @@ class BenchmarkWorker:
...
@@ -400,7 +399,7 @@ class BenchmarkWorker:
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
L
ist
[
int
]
=
None
,
block_quant_shape
:
l
ist
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
,
use_deep_gemm
:
bool
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
)
->
tuple
[
dict
[
str
,
int
],
float
]:
current_platform
.
seed_everything
(
self
.
seed
)
current_platform
.
seed_everything
(
self
.
seed
)
...
@@ -532,7 +531,7 @@ def save_configs(
...
@@ -532,7 +531,7 @@ def save_configs(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
L
ist
[
int
],
block_quant_shape
:
l
ist
[
int
],
)
->
None
:
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
...
@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
...
@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
config
=
get_config
(
model
=
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
config
=
get_config
(
model
=
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
model_prefix
:
if
args
.
model_prefix
:
config
=
getattr
(
config
,
args
.
model_prefix
)
config
=
getattr
(
config
,
args
.
model_prefix
)
config
=
SimpleNamespace
(
**
config
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
E
=
config
.
ffn_config
.
moe_num_experts
...
@@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
...
@@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
dtype
=
(
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
torch
.
float16
if
current_platform
.
is_rocm
()
else
getattr
(
torch
,
config
.
torch_dtype
)
)
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
block_quant_shape
=
get_weight_block_size_safety
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment