Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a73122de
Unverified
Commit
a73122de
authored
Mar 13, 2025
by
Jee Jee Li
Committed by
GitHub
Mar 13, 2025
Browse files
[Bugfix] fix benchmark moe (#14653)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
bd44b812
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
5 deletions
+21
-5
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+21
-5
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
a73122de
...
@@ -365,6 +365,7 @@ class BenchmarkWorker:
...
@@ -365,6 +365,7 @@ class BenchmarkWorker:
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
List
[
int
]
=
None
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
)
->
tuple
[
dict
[
str
,
int
],
float
]:
current_platform
.
seed_everything
(
self
.
seed
)
current_platform
.
seed_everything
(
self
.
seed
)
dtype_str
=
get_config_dtype_str
(
dtype
,
dtype_str
=
get_config_dtype_str
(
dtype
,
...
@@ -385,10 +386,17 @@ class BenchmarkWorker:
...
@@ -385,10 +386,17 @@ class BenchmarkWorker:
else
:
else
:
config
=
op_config
[
min
(
op_config
.
keys
(),
config
=
op_config
[
min
(
op_config
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
kernel_time
=
benchmark_config
(
config
,
shard_intermediate_size
,
hidden_size
,
num_tokens
,
topk
,
dtype
,
use_fp8_w8a8
,
num_experts
,
use_int8_w8a16
)
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
block_quant_shape
=
block_quant_shape
)
return
config
,
kernel_time
return
config
,
kernel_time
def
tune
(
def
tune
(
...
@@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
...
@@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
f
.
write
(
"
\n
"
)
f
.
write
(
"
\n
"
)
def
get_weight_block_size_safety
(
config
,
default_value
=
None
):
quantization_config
=
getattr
(
config
,
'quantization_config'
,
{})
if
isinstance
(
quantization_config
,
dict
):
return
quantization_config
.
get
(
'weight_block_size'
,
default_value
)
return
default_value
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
print
(
args
)
block_quant_shape
=
None
block_quant_shape
=
None
...
@@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
...
@@ -508,7 +524,7 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
block_quant_shape
=
config
.
quantization_config
[
'
weight_block_size
'
]
block_quant_shape
=
get_
weight_block_size
_safety
(
config
)
elif
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
elif
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
E
=
config
.
num_experts
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment