sglang · Commit 849f58d6 (unverified)

Update fused_moe's benchmark (#3346)

Authored Feb 08, 2025 by GaoYuYang; committed by GitHub on Feb 08, 2025
Parent: 64480df4

1 changed file with 75 additions and 22 deletions:

benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py (+75, -22)
@@ -2,6 +2,7 @@ import argparse
 import torch
 import triton
+import vllm
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
@@ -29,11 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] == "DeepseekV2ForCausalLM":
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
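
Note on the hunk above: for DeepSeek-V2/V3 configs the expert count, top-k, and MoE intermediate size now come from n_routed_experts, num_experts_per_tok, and moe_intermediate_size (rather than the dense intermediate_size), and the shard size divides by the tp_size function argument instead of args.tp_size. A minimal sketch of the same lookup, where the model id and tp_size value are illustrative assumptions, not from the commit:

    # Sketch: derive MoE benchmark shapes from a DeepSeek-style config.
    # The model id and tp_size below are placeholders.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
    tp_size = 8
    if config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = config.n_routed_experts              # routed experts
        topk = config.num_experts_per_tok        # experts selected per token
        intermediate_size = config.moe_intermediate_size
        # w1 fuses the gate and up projections (factor 2), sharded across TP ranks.
        shard_intermediate_size = 2 * intermediate_size // tp_size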
@@ -41,12 +42,27 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
 
+    vllm_version_num = (
+        vllm.__version_tuple__[0] * 100 + vllm.__version_tuple__[1] * 10 + vllm.__version_tuple__[2]
+    )
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
+        assert (
+            vllm_version_num >= 66
+        ), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
+
     shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
+        "block_shape": block_shape,
     }
     print(f"{shape_configs=}")
     return shape_configs
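
The version check above encodes vllm.__version_tuple__ as a two-digit-per-field number, so v0.6.6.post1 maps to 66, the minimum that supports block-wise quantized fp8 fused_moe. A tiny worked example (the tuple value is an assumption, not read from an installed vllm):

    # Worked example of the version encoding used above.
    version_tuple = (0, 6, 6)  # e.g. what vllm 0.6.6.post1 would report
    vllm_version_num = version_tuple[0] * 100 + version_tuple[1] * 10 + version_tuple[2]
    print(vllm_version_num)  # 66 -> passes the >= 66 assertion

One caveat of this encoding: it assumes minor and patch components stay below 10 (for instance 0.6.10 and 0.7.0 would both encode to 70).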
@@ -63,21 +79,39 @@ def fused_moe_vllm_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
-    return fused_moe_vllm(
-        x,
-        w1,
-        w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
-        use_fp8_w8a8=use_fp8_w8a8,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
-    )
+    if block_shape is not None:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            block_shape=block_shape,
+        )
+    else:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
 
 
 def fused_moe_sglang_api(
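
The wrapper above keeps two otherwise identical call paths so that block_shape is only passed when it is set; on vllm builds whose fused_moe signature predates the keyword, the argument is then never forwarded. A generic sketch of that pattern (illustration only, not what the diff itself does) would inspect the callee's signature:

    # Illustration: forward an optional keyword only if the callee accepts it.
    import inspect

    def supports_kwarg(fn, name: str) -> bool:
        return name in inspect.signature(fn).parameters

    def call_with_optional(fn, *args, block_shape=None, **kwargs):
        if block_shape is not None and supports_kwarg(fn, "block_shape"):
            kwargs["block_shape"] = block_shape
        return fn(*args, **kwargs)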
@@ -91,6 +125,7 @@ def fused_moe_sglang_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
     return fused_moe_sglang(
         x,
@@ -105,6 +140,7 @@ def fused_moe_sglang_api(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
@@ -141,8 +177,10 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     shard_intermediate_size = model_config["shard_intermediate_size"]
     topk = model_config["topk"]
     dtype = model_config["dtype"]
+    block_shape = getattr(model_config, "block_shape", None)
 
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
 
+    w1_scale = w2_scale = a1_scale = a2_scale = None
     if use_fp8:
         init_dtype = dtype
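
One detail worth flagging in the hunk above: get_model_config returns a plain dict (the shape_configs mapping shown earlier), and getattr does not perform key lookup on dicts, so getattr(model_config, "block_shape", None) will always fall back to None here; the dictionary-style model_config.get("block_shape", None) would be the equivalent key access. A minimal illustration (the dict value is an assumed example):

    # getattr vs. dict lookup on a plain dict.
    model_config = {"block_shape": [128, 128]}  # assumed example value
    print(getattr(model_config, "block_shape", None))  # None: dicts expose no such attribute
    print(model_config.get("block_shape", None))       # [128, 128]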
@@ -154,16 +192,29 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
         )
         w1 = w1.to(torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fn)
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
-        a1_scale = torch.randn(1, dtype=torch.float32)
-        a2_scale = torch.randn(1, dtype=torch.float32)
+        if block_shape is None:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            a1_scale = torch.randn(1, dtype=torch.float32)
+            a2_scale = torch.randn(1, dtype=torch.float32)
+        else:
+            block_n, block_k = block_shape[0], block_shape[1]
+            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
+            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
+            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
+            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
+            w1_scale = torch.rand((num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
+            w2_scale = torch.rand((num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
     else:
         w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
         w2 = torch.randn(num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype)
-        w1_scale = w2_scale = a1_scale = a2_scale = None
 
     input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
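
For the block-wise branch above, each weight gets one fp32 scale per (block_n x block_k) tile, with ceiling division on both dimensions. A worked example with assumed sizes (block_shape = [128, 128], shard_intermediate_size = 3072, hidden_size = 4096; these numbers are illustrative, not from the commit):

    # Illustrative tile-count arithmetic for block-wise fp8 scales.
    block_n, block_k = 128, 128
    shard_intermediate_size = 3072
    hidden_size = 4096
    n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n        # 24
    n_tiles_w2 = (hidden_size + block_n - 1) // block_n                    # 32
    k_tiles_w1 = (hidden_size + block_k - 1) // block_k                    # 32
    k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k   # 12
    # w1_scale shape: (num_experts, 24, 32); w2_scale shape: (num_experts, 32, 12)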
@@ -185,6 +236,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )
 
     torch.cuda.synchronize()
@@ -201,6 +253,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )[0],
         quantiles=quantiles,
     )
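
The quantiles keyword in this last hunk belongs to triton.testing.do_bench, which times the wrapped call repeatedly and returns one latency per requested quantile. A standalone sketch of the same pattern (the quantile values and the timed op are assumptions, and a CUDA device is required):

    # Standalone sketch of triton.testing.do_bench with quantiles (requires a GPU).
    import torch
    import triton

    quantiles = [0.5, 0.2, 0.8]  # median, p20, p80 (assumed values)
    a = torch.randn(1024, 1024, device="cuda")
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.mm(a, a), quantiles=quantiles)
    print(f"median {ms:.3f} ms, p20 {min_ms:.3f} ms, p80 {max_ms:.3f} ms")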