Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a434b07
Unverified
Commit
3a434b07
authored
Jun 03, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 03, 2024
Browse files
[Kernel] Enhance MoE benchmarking & tuning script (#4921)
parent
bd0e7802
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
346 additions
and
253 deletions
+346
-253
benchmarks/kernels/benchmark_mixtral_moe.py
benchmarks/kernels/benchmark_mixtral_moe.py
+0
-239
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+319
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+27
-14
No files found.
benchmarks/kernels/benchmark_mixtral_moe.py
deleted
100644 → 0
View file @
bd0e7802
import
argparse
import
json
import
os
import
sys
import
torch
import
torch.nn.functional
as
F
import
triton
from
tqdm
import
tqdm
from
vllm.model_executor.layers.fused_moe
import
(
fused_moe
,
get_config_file_name
)
def
main
(
model
,
tp_size
,
gpu
,
dtype
:
str
):
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
str
(
gpu
)
method
=
fused_moe
for
bs
in
[
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
]:
run_grid
(
bs
,
model
=
model
,
method
=
method
,
gpu
=
gpu
,
tp_size
=
tp_size
,
dtype
=
dtype
)
def
run_grid
(
bs
,
model
,
method
,
gpu
,
tp_size
,
dtype
:
str
):
if
model
==
'8x7B'
:
d_model
=
4096
model_intermediate_size
=
14336
num_layers
=
32
elif
model
==
'8x22B'
:
d_model
=
6144
model_intermediate_size
=
16384
num_layers
=
56
else
:
raise
ValueError
(
f
'Unsupported Mixtral model
{
model
}
'
)
num_total_experts
=
8
top_k
=
2
# tp_size = 2
num_calls
=
100
num_warmup_trials
=
1
num_trials
=
1
configs
=
[]
for
block_size_n
in
[
32
,
64
,
128
,
256
]:
for
block_size_m
in
[
16
,
32
,
64
,
128
,
256
]:
for
block_size_k
in
[
64
,
128
,
256
]:
for
group_size_m
in
[
1
,
16
,
32
,
64
]:
for
num_warps
in
[
4
,
8
]:
for
num_stages
in
[
2
,
3
,
4
,
5
]:
configs
.
append
({
"BLOCK_SIZE_M"
:
block_size_m
,
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_K"
:
block_size_k
,
"GROUP_SIZE_M"
:
group_size_m
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
})
best_config
=
None
best_time_us
=
1e20
print
(
f
'
{
tp_size
=
}
{
bs
=
}
'
)
for
config
in
tqdm
(
configs
):
# warmup
try
:
for
_
in
range
(
num_warmup_trials
):
run_timing
(
num_calls
=
num_calls
,
bs
=
bs
,
d_model
=
d_model
,
num_total_experts
=
num_total_experts
,
top_k
=
top_k
,
tp_size
=
tp_size
,
model_intermediate_size
=
model_intermediate_size
,
method
=
method
,
config
=
config
,
dtype
=
dtype
,
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
continue
# trial
for
_
in
range
(
num_trials
):
kernel_dur_ms
=
run_timing
(
num_calls
=
num_calls
,
bs
=
bs
,
d_model
=
d_model
,
num_total_experts
=
num_total_experts
,
top_k
=
top_k
,
tp_size
=
tp_size
,
model_intermediate_size
=
model_intermediate_size
,
method
=
method
,
config
=
config
,
dtype
=
dtype
,
)
kernel_dur_us
=
1000
*
kernel_dur_ms
model_dur_ms
=
kernel_dur_ms
*
num_layers
if
kernel_dur_us
<
best_time_us
:
best_config
=
config
best_time_us
=
kernel_dur_us
tqdm
.
write
(
f
'
{
kernel_dur_us
=
:.
1
f
}
{
model_dur_ms
=
:.
1
f
}
'
f
'
{
bs
=
}
{
tp_size
=
}
{
top_k
=
}
{
num_total_experts
=
}
'
f
'
{
d_model
=
}
{
model_intermediate_size
=
}
{
num_layers
=
}
'
)
print
(
"best_time_us"
,
best_time_us
)
print
(
"best_config"
,
best_config
)
# holds Dict[str, Dict[str, int]]
filename
=
get_config_file_name
(
num_total_experts
,
model_intermediate_size
//
tp_size
,
"float8"
if
dtype
==
"float8"
else
None
)
print
(
f
"writing config to file
{
filename
}
"
)
existing_content
=
{}
if
os
.
path
.
exists
(
filename
):
with
open
(
filename
,
"r"
)
as
f
:
existing_content
=
json
.
load
(
f
)
existing_content
[
str
(
bs
)]
=
best_config
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
existing_content
,
f
,
indent
=
4
)
f
.
write
(
"
\n
"
)
def
run_timing
(
num_calls
:
int
,
bs
:
int
,
d_model
:
int
,
num_total_experts
:
int
,
top_k
:
int
,
tp_size
:
int
,
model_intermediate_size
:
int
,
method
,
config
,
dtype
:
str
)
->
float
:
shard_intermediate_size
=
model_intermediate_size
//
tp_size
hidden_states
=
torch
.
rand
(
(
bs
,
d_model
),
device
=
"cuda:0"
,
dtype
=
torch
.
float16
,
)
w1
=
torch
.
rand
(
(
num_total_experts
,
2
*
shard_intermediate_size
,
d_model
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
w2
=
torch
.
rand
(
(
num_total_experts
,
d_model
,
shard_intermediate_size
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
w1_scale
=
None
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
dtype
==
"float8"
:
w1
=
w1
.
to
(
torch
.
float8_e4m3fn
)
w2
=
w2
.
to
(
torch
.
float8_e4m3fn
)
w1_scale
=
torch
.
ones
(
num_total_experts
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
ones
(
num_total_experts
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
ones
(
1
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
ones
(
1
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
)
gating_output
=
F
.
softmax
(
torch
.
rand
(
(
num_calls
,
bs
,
num_total_experts
),
device
=
hidden_states
.
device
,
dtype
=
torch
.
float32
,
),
dim
=-
1
)
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start_event
.
record
()
for
i
in
range
(
num_calls
):
hidden_states
=
method
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
gating_output
=
gating_output
[
i
],
topk
=
2
,
renormalize
=
True
,
inplace
=
True
,
override_config
=
config
,
use_fp8
=
dtype
==
"float8"
,
)
end_event
.
record
()
end_event
.
synchronize
()
dur_ms
=
start_event
.
elapsed_time
(
end_event
)
/
num_calls
return
dur_ms
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
prog
=
'benchmark_mixtral_moe'
,
description
=
'Benchmark and tune the fused_moe kernel'
,
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'auto'
,
choices
=
[
'float8'
,
'float16'
],
help
=
'Data type used for fused_moe kernel computations'
,
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'8x7B'
,
choices
=
[
'8x7B'
,
'8x22B'
],
help
=
'The Mixtral model to benchmark'
)
parser
.
add_argument
(
'--tp-size'
,
type
=
int
,
default
=
2
,
help
=
'Tensor paralleli size'
)
parser
.
add_argument
(
'--gpu'
,
type
=
int
,
default
=
0
,
help
=
"GPU ID for benchmarking"
)
args
=
parser
.
parse_args
()
sys
.
exit
(
main
(
args
.
model
,
args
.
tp_size
,
args
.
gpu
,
args
.
dtype
))
benchmarks/kernels/benchmark_moe.py
0 → 100644
View file @
3a434b07
import
argparse
import
time
from
datetime
import
datetime
from
typing
import
Any
,
Dict
,
List
,
Tuple
import
ray
import
torch
import
triton
from
ray.experimental.tqdm_ray
import
tqdm
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
def
benchmark_config
(
config
:
Dict
[
str
,
int
],
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
num_iters
:
int
=
100
,
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
w1
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
)
w2
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
w1_scale
=
None
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
use_fp8
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
w1
=
w1
.
to
(
torch
.
float8_e4m3fn
)
w2
=
w2
.
to
(
torch
.
float8_e4m3fn
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
def
prepare
(
i
:
int
):
input_gating
.
copy_
(
gating_output
[
i
])
def
run
():
fused_moe
(
x
,
w1
,
w2
,
input_gating
,
topk
,
renormalize
=
True
,
inplace
=
True
,
override_config
=
config
,
use_fp8
=
use_fp8
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
)
# JIT compilation & warmup
run
()
torch
.
cuda
.
synchronize
()
# Capture 10 invocations with CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
10
):
run
()
torch
.
cuda
.
synchronize
()
# Warmup
for
_
in
range
(
5
):
graph
.
replay
()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
latencies
=
[]
for
i
in
range
(
num_iters
):
prepare
(
i
)
torch
.
cuda
.
synchronize
()
start_event
.
record
()
graph
.
replay
()
end_event
.
record
()
end_event
.
synchronize
()
latencies
.
append
(
start_event
.
elapsed_time
(
end_event
))
avg
=
sum
(
latencies
)
/
(
num_iters
*
10
)
*
1000
# us
graph
.
reset
()
return
avg
def
get_configs_compute_bound
()
->
List
[
Dict
[
str
,
int
]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs
=
[]
for
num_stages
in
[
2
,
3
,
4
,
5
]:
for
block_m
in
[
16
,
32
,
64
,
128
,
256
]:
for
block_k
in
[
64
,
128
,
256
]:
for
block_n
in
[
32
,
64
,
128
,
256
]:
for
num_warps
in
[
4
,
8
]:
for
group_size
in
[
1
,
16
,
32
,
64
]:
configs
.
append
({
"BLOCK_SIZE_M"
:
block_m
,
"BLOCK_SIZE_N"
:
block_n
,
"BLOCK_SIZE_K"
:
block_k
,
"GROUP_SIZE_M"
:
group_size
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
})
return
configs
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
def
__init__
(
self
,
seed
:
int
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
seed
)
self
.
seed
=
seed
def
benchmark
(
self
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
)
->
Tuple
[
Dict
[
str
,
int
],
float
]:
torch
.
cuda
.
manual_seed_all
(
self
.
seed
)
dtype_str
=
"float8"
if
use_fp8
else
None
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config
=
get_moe_configs
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
)
if
op_config
is
None
:
config
=
get_default_config
(
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype_str
)
else
:
config
=
op_config
[
min
(
op_config
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
return
config
,
kernel_time
def
tune
(
self
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
search_space
:
List
[
Dict
[
str
,
int
]],
)
->
Dict
[
str
,
int
]:
best_config
=
None
best_time
=
float
(
"inf"
)
for
config
in
tqdm
(
search_space
):
try
:
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
,
num_iters
=
10
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
continue
if
kernel_time
<
best_time
:
best_time
=
kernel_time
best_config
=
config
now
=
datetime
.
now
()
print
(
f
"
{
now
.
ctime
()
}
] Completed tuning for batch_size=
{
num_tokens
}
"
)
return
best_config
def
sort_config
(
config
:
Dict
[
str
,
int
])
->
Dict
[
str
,
int
]:
return
{
"BLOCK_SIZE_M"
:
config
[
"BLOCK_SIZE_M"
],
"BLOCK_SIZE_N"
:
config
[
"BLOCK_SIZE_N"
],
"BLOCK_SIZE_K"
:
config
[
"BLOCK_SIZE_K"
],
"GROUP_SIZE_M"
:
config
[
"GROUP_SIZE_M"
],
"num_warps"
:
config
[
"num_warps"
],
"num_stages"
:
config
[
"num_stages"
],
}
def
save_configs
(
configs
:
Dict
[
int
,
Dict
[
str
,
int
]],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
)
->
None
:
dtype_str
=
"float8"
if
use_fp8
else
None
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
configs
,
f
,
indent
=
4
)
f
.
write
(
"
\n
"
)
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
else
:
# Default: Mixtral.
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
hidden_size
=
config
.
hidden_size
dtype
=
config
.
torch_dtype
use_fp8
=
args
.
dtype
==
"fp8"
if
args
.
batch_size
is
None
:
batch_sizes
=
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
else
:
batch_sizes
=
[
args
.
batch_size
]
ray
.
init
()
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
)
for
_
in
range
(
num_gpus
)]
def
_distribute
(
method
:
str
,
inputs
:
List
[
Any
])
->
List
[
Any
]:
outputs
=
[]
worker_idx
=
0
for
input_args
in
inputs
:
worker
=
workers
[
worker_idx
]
worker_method
=
getattr
(
worker
,
method
)
output
=
worker_method
.
remote
(
*
input_args
)
outputs
.
append
(
output
)
worker_idx
=
(
worker_idx
+
1
)
%
num_gpus
return
ray
.
get
(
outputs
)
if
args
.
tune
:
search_space
=
get_configs_compute_bound
()
print
(
f
"Start tuning over
{
len
(
search_space
)
}
configurations..."
)
start
=
time
.
time
()
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
,
search_space
)
for
batch_size
in
batch_sizes
])
best_configs
=
{
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
}
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
end
=
time
.
time
()
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
else
:
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
for
batch_size
in
batch_sizes
])
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
print
(
f
"Batch size:
{
batch_size
}
, config:
{
config
}
"
)
print
(
f
"Kernel time:
{
kernel_time
:.
2
f
}
us"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8"
],
default
=
"auto"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
main
(
args
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
3a434b07
...
...
@@ -308,6 +308,30 @@ def get_moe_configs(E: int, N: int,
return
None
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
)
->
Dict
[
str
,
int
]:
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
}
if
M
<=
E
:
config
=
{
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
32
,
'BLOCK_SIZE_K'
:
64
,
'GROUP_SIZE_M'
:
1
}
return
config
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
...
...
@@ -382,20 +406,9 @@ def fused_experts(hidden_states: torch.Tensor,
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
}
if
M
<=
E
:
config
=
{
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
32
,
'BLOCK_SIZE_K'
:
64
,
'GROUP_SIZE_M'
:
1
}
config
=
get_default_config
(
M
,
E
,
N
,
w1
.
shape
[
2
],
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
)
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment