Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4589b940
Unverified
Commit
4589b940
authored
Jun 10, 2025
by
Tianyu Guo
Committed by
GitHub
Jun 09, 2025
Browse files
[Bugfix] Fix benchmark_moe.py (#19016)
Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
parent
cc867be1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
10 deletions
+4
-10
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+4
-10
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
4589b940
...
...
@@ -7,7 +7,6 @@ import time
 from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
-from types import SimpleNamespace
 from typing import Any, TypedDict

 import ray
...
...
@@ -43,7 +42,7 @@ def benchmark_config(
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
-    block_quant_shape: List[int] = None,
+    block_quant_shape: list[int] = None,
use_deep_gemm
:
bool
=
False
,
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
...
...
@@ -400,7 +399,7 @@ class BenchmarkWorker:
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
-        block_quant_shape: List[int] = None,
+        block_quant_shape: list[int] = None,
use_deep_gemm
:
bool
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
current_platform
.
seed_everything
(
self
.
seed
)
...
...
@@ -532,7 +531,7 @@ def save_configs(
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
-    block_quant_shape: List[int],
+    block_quant_shape: list[int],
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
...
...
@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
config
=
get_config
(
model
=
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
model_prefix
:
config
=
getattr
(
config
,
args
.
model_prefix
)
-    config = SimpleNamespace(**config)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
...
...
@@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
hidden_size
=
config
.
hidden_size
-    dtype = (
-        torch.float16
-        if current_platform.is_rocm()
-        else getattr(torch, config.torch_dtype)
-    )
+    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment