sglang, commit 33deca81 (Unverified)
Authored Dec 02, 2024 by Lianmin Zheng; committed by GitHub, Dec 02, 2024
Add more fused moe benchmark utilities (#2314)
Parent: 18108abe
Showing 3 changed files with 294 additions and 23 deletions.
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py  +275 -0
benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py  +6 -15
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py  +13 -8
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
new file (mode 100644)
import argparse

import torch
import triton
from torch.nn import functional as F
from transformers import AutoConfig

from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
def get_model_config(model_name: str, tp_size: int):
    """Get model configuration parameters"""
    config = AutoConfig.from_pretrained(model_name)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size
    else:
        # Default: Mixtral
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // tp_size

    shape_configs = {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs
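# Illustration (editor's assumption, not part of the committed file): for the
# default model mistralai/Mixtral-8x7B-Instruct-v0.1 with tp_size=2, the Mixtral
# branch above would return roughly the following, based on that model's
# published config (hidden_size=4096, intermediate_size=14336, 8 experts, top-2):
#
#     shape_configs = {
#         "num_experts": 8,
#         "topk": 2,
#         "hidden_size": 4096,
#         "shard_intermediate_size": 14336,  # 2 * 14336 // 2
#         "dtype": torch.bfloat16,
#     }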
def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    M, _ = hidden_states.shape
    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )
    topk_weights = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
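# Reading aid (editor's comment, not part of the committed file): for M tokens
# and E experts with topk=k, the routing above is
#     probs = F.softmax(gating_output.float(), dim=-1)        # (M, E), rows sum to 1
#     weights, ids = torch.topk(probs, k, dim=-1)              # (M, k) values and indices
#     weights = weights / weights.sum(dim=-1, keepdim=True)    # rows sum to 1 again
# so each token keeps k expert indices plus k mixing weights. Note that the
# torch.empty buffers allocated above are immediately overwritten by the
# softmax/topk results.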
@torch.compile
def fused_moe_torch(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
) -> torch.Tensor:
    assert not use_fp8_w8a8, "Not supported"
    topk_weights, topk_ids = fused_topk_native(
        hidden_states=x,
        gating_output=input_gating,
        topk=topk,
        renormalize=True,
    )
    w13_weights = w1[topk_ids]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = w2[topk_ids]
    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
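# Shape annotation (editor's comment, not part of the committed file). In the
# einsum subscripts above: t = token, a = top-k expert slot, i = hidden dim,
# o = intermediate dim, with O = shard_intermediate_size // 2.
#     x                          : (T, I)
#     w13_weights = w1[topk_ids] : (T, A, 2*O, I); chunk on dim=2 gives
#     w1_weights, w3_weights     : (T, A, O, I) each
#     w2_weights  = w2[topk_ids] : (T, A, I, O)
#     x1, x3                     : (T, A, O)
#     expert_outs                : (T, A, I)
#     final einsum over slots a  : (T, I)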
def fused_moe_torch_compile(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_torch(
        x,
        w1,
        w2,
        input_gating,
        topk,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
def fused_moe_sglang_api(
    x,
    w1,
    w2,
    input_gating,
    topk,
    use_fp8_w8a8=False,
    w1_scale=None,
    w2_scale=None,
    a1_scale=None,
    a2_scale=None,
):
    return fused_moe_triton(
        x,
        w1,
        w2,
        input_gating,
        topk,
        renormalize=True,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
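# Annotation (editor's comment, not part of the committed file): the two
# providers benchmarked below differ only in the backend they dispatch to.
# fused_moe_torch_compile runs the pure-PyTorch reference above under
# @torch.compile, while fused_moe_sglang_api calls sglang's fused Triton kernel
# with renormalize=True and inplace=True; both accept the same
# (x, w1, w2, input_gating, topk, ...) arguments.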
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=list(range(1, 5)),
        line_arg="provider",
        line_vals=[
            "fused_moe_triton",
            "fused_moe_torch_compile",
        ],
        line_names=[
            "fused_moe_triton",
            "fused_moe_torch_compile",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
        ],
        ylabel="Time (ms)",
        plot_name="fused-moe-performance",
        args={},
    )
)
def benchmark(batch_size, provider, model_config, use_fp8=False):
    print(f"benchmark {provider} with batch_size={batch_size}")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)

    num_tokens = batch_size
    num_experts = model_config["num_experts"]
    hidden_size = model_config["hidden_size"]
    shard_intermediate_size = model_config["shard_intermediate_size"]
    topk = model_config["topk"]
    dtype = model_config["dtype"]

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

    if use_fp8:
        init_dtype = dtype
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)
    else:
        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
        )
        w1_scale = w2_scale = a1_scale = a2_scale = None

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    # Warmup
    api_func = (
        fused_moe_torch_compile
        if provider == "fused_moe_torch_compile"
        else fused_moe_sglang_api
    )
    for _ in range(10):
        y = api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: api_func(
            x,
            w1,
            w2,
            input_gating,
            topk,
            use_fp8_w8a8=use_fp8,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        )[0],
        quantiles=quantiles,
    )
    return ms, min_ms, max_ms
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument("--tp-size", type=int, default=2)
    parser.add_argument("--use-fp8", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/fused_moe_torch_compile/",
    )
    args = parser.parse_args()

    model_config = get_model_config(args.model, args.tp_size)
    benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=args.save_path,
        model_config=model_config,
        use_fp8=args.use_fp8,
    )
if __name__ == "__main__":
    main()
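The new script is driven entirely by the flags defined in main(): something like
python benchmark_torch_compile_fused_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp-size 2
(optionally adding --use-fp8 and --save-path) sweeps batch sizes 1-4 for both
providers and writes the resulting table and plot to the save path. The snippet
below is a minimal standalone sketch, assumed for illustration rather than taken
from this commit, of the same measurement pattern benchmark() uses:
triton.testing.do_bench with median / 20th / 80th percentile quantiles.

import torch
import triton

# Time a stand-in CUDA op the same way benchmark() times api_func(...):
# do_bench returns the latencies (ms) at the requested quantiles.
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

ms, min_ms, max_ms = triton.testing.do_bench(
    lambda: a @ b,  # stand-in for api_func(x, w1, w2, input_gating, topk, ...)
    quantiles=[0.5, 0.2, 0.8],
)
print(f"median={ms:.3f} ms  p20={min_ms:.3f} ms  p80={max_ms:.3f} ms")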
benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
 import argparse
-import numbers
-from typing import Optional

 import torch
 import triton
-from torch.nn import init
-from torch.nn.parameter import Parameter
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    get_moe_configs as get_moe_configs_vllm,
-)
-from vllm.utils import FlexibleArgumentParser

 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
-    get_moe_configs as get_moe_configs_sglang,
-)


 def get_model_config(model_name: str, tp_size: int):
...
@@ -39,19 +28,21 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
-        # Default: Mixtral, Grok1, etc.
+        # Default: Mixtral
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size

-    return {
+    shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
     }
+    print(f"{shape_configs=}")
+    return shape_configs


 def fused_moe_vllm_api(
...
@@ -133,7 +124,7 @@ def fused_moe_sglang_api(
     )
 )
 def benchmark(batch_size, provider, model_config, use_fp8=False):
-    print(f"benchmark for batch_size={batch_size}")
+    print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
...
@@ -210,7 +201,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
 def main():
-    parser = FlexibleArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
...
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
 # Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
 import argparse
+import json
 import time
 from datetime import datetime
 from typing import Any, Dict, List, Tuple, TypedDict
...
@@ -9,10 +10,14 @@ import torch
 import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
-from vllm.platforms import current_platform
-from vllm.utils import FlexibleArgumentParser

-from sglang.srt.layers.fused_moe_triton.fused_moe import *
+from sglang.srt.layers.fused_moe_triton.fused_moe import (
+    fused_moe,
+    get_config_dtype_str,
+    get_config_file_name,
+    get_default_config,
+    get_moe_configs,
+)


 class BenchmarkConfig(TypedDict):
...
@@ -92,7 +97,7 @@ def benchmark_config(
         input_gating.copy_(gating_output[i])

     def run():
-        from sglang.srt.layers.fused_moe_triton.fused_moe import override_config
+        from sglang.srt.layers.fused_moe_triton import override_config

         with override_config(config):
             fused_moe(
...
@@ -174,7 +179,7 @@ class BenchmarkWorker:
     def __init__(self, seed: int) -> None:
         torch.set_default_device("cuda")
-        current_platform.seed_everything(seed)
+        torch.cuda.manual_seed_all(0)
         self.seed = seed

     def benchmark(
...
@@ -188,7 +193,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
     ) -> Tuple[Dict[str, int], float]:
-        current_platform.seed_everything(self.seed)
+        torch.cuda.manual_seed_all(0)
         dtype_str = get_config_dtype_str(
             dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
         )
...
@@ -319,7 +324,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
-        # Default: Mixtral.
+        # Default: Mixtral
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
...
@@ -430,7 +435,7 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
...