Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8027a724
Unverified
Commit
8027a724
authored
Jan 17, 2025
by
Divakar Verma
Committed by
GitHub
Jan 17, 2025
Browse files
[ROCm][MoE] moe tuning support for rocm (#12049)
Signed-off-by:
Divakar Verma
<
divakar.verma@amd.com
>
parent
d75ab55f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
224 additions
and
48 deletions
+224
-48
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+224
-48
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
8027a724
import
argparse
import
argparse
import
time
import
time
from
datetime
import
datetime
from
datetime
import
datetime
from
itertools
import
product
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
TypedDict
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
TypedDict
import
ray
import
ray
...
@@ -11,7 +12,10 @@ from transformers import AutoConfig
...
@@ -11,7 +12,10 @@ from transformers import AutoConfig
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
,
is_navi
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
current_platform
.
is_rocm
(
)
and
not
is_navi
()
else
torch
.
float8_e4m3fn
class
BenchmarkConfig
(
TypedDict
):
class
BenchmarkConfig
(
TypedDict
):
...
@@ -80,8 +84,8 @@ def benchmark_config(
...
@@ -80,8 +84,8 @@ def benchmark_config(
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
w1
=
w1
.
to
(
torch
.
float8_e4m3fn
)
w1
=
w1
.
to
(
FP8_DTYPE
)
w2
=
w2
.
to
(
torch
.
float8_e4m3fn
)
w2
=
w2
.
to
(
FP8_DTYPE
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
...
@@ -141,28 +145,172 @@ def benchmark_config(
...
@@ -141,28 +145,172 @@ def benchmark_config(
return
avg
return
avg
def
get_configs_compute_bound
()
->
List
[
Dict
[
str
,
int
]]:
def get_rocm_tuning_space(use_fp16):
    """Return the candidate values for every ROCm kernel tuning parameter.

    Args:
        use_fp16: When truthy, the fp16 space is produced: BLOCK_SIZE_K
            may also be 16 and the ``matrix_instr_nonkdim`` and ``kpack``
            parameters are included. Otherwise those two parameters are
            omitted and BLOCK_SIZE_K starts at 32.

    Returns:
        A dict mapping each parameter name to its list of candidate values.
    """
    # The M and N tile sizes share a single candidate list.
    tile_mn_choices = [16, 32, 64, 128, 256]

    if use_fp16:
        tile_k_choices = [16, 32, 64, 128, 256]
    else:
        # BLOCK_K=16 not supported for fp8
        tile_k_choices = [32, 64, 128, 256]

    space = {
        "BLOCK_SIZE_M": tile_mn_choices,
        "BLOCK_SIZE_N": tile_mn_choices,
        "BLOCK_SIZE_K": tile_k_choices,
        "GROUP_SIZE_M": [1, 4, 8, 16, 32],
        "num_warps": [1, 2, 4, 8],
        "num_stages": [2],
        "waves_per_eu": [0],
    }

    # These two knobs are only part of the fp16 search space.
    if use_fp16:
        space["matrix_instr_nonkdim"] = [16, 32]
        space["kpack"] = [1, 2]

    return space
def
get_configs_compute_bound
(
use_fp16
)
->
List
[
Dict
[
str
,
int
]]:
configs
:
List
[
BenchmarkConfig
]
=
[]
if
current_platform
.
is_rocm
():
param_ranges
=
get_rocm_tuning_space
(
use_fp16
)
else
:
# Reduced search space for faster tuning.
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
# prune the search space.
configs
:
List
[
BenchmarkConfig
]
=
[]
block_m_range
=
[
16
,
32
,
64
,
128
,
256
]
for
num_stages
in
[
2
,
3
,
4
,
5
]:
block_n_range
=
[
32
,
64
,
128
,
256
]
for
block_m
in
[
16
,
32
,
64
,
128
,
256
]:
block_k_range
=
[
64
,
128
,
256
]
for
block_k
in
[
64
,
128
,
256
]:
num_warps_range
=
[
4
,
8
]
for
block_n
in
[
32
,
64
,
128
,
256
]:
group_m_range
=
[
1
,
16
,
32
,
64
]
for
num_warps
in
[
4
,
8
]:
num_stage_range
=
[
2
,
3
,
4
,
5
]
for
group_size
in
[
1
,
16
,
32
,
64
]:
configs
.
append
({
param_ranges
=
{
"BLOCK_SIZE_M"
:
block_m
,
"BLOCK_SIZE_M"
:
block_m_range
,
"BLOCK_SIZE_N"
:
block_n
,
"BLOCK_SIZE_N"
:
block_n_range
,
"BLOCK_SIZE_K"
:
block_k
,
"BLOCK_SIZE_K"
:
block_k_range
,
"GROUP_SIZE_M"
:
group_size
,
"GROUP_SIZE_M"
:
group_m_range
,
"num_warps"
:
num_warps
,
"num_warps"
:
num_warps_range
,
"num_stages"
:
num_stages
,
"num_stages"
:
num_stage_range
,
})
}
keys
,
values
=
zip
(
*
param_ranges
.
items
())
for
config_values
in
product
(
*
values
):
config
=
dict
(
zip
(
keys
,
config_values
))
configs
.
append
(
config
)
return
configs
return
configs
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
                            search_space, is_fp16):
    """Prune ``search_space`` for two GEMM shapes and merge the survivors.

    ``prune_rocm_configs`` is run once per (N, K) shape below, both times
    with M = ``num_tokens * 2``; the two pruned lists are then combined
    with ``merge_unique_dicts`` so each surviving config appears once.

    Args:
        num_tokens: Token count; doubled to form the M dimension.
        shard_intermediate_size: Used as N of the first shape and, halved
            with integer division, as K of the second.
        hidden_size: Used as K of the first shape and N of the second.
        search_space: Candidate config dicts to prune.
        is_fp16: Forwarded unchanged to ``prune_rocm_configs``.

    Returns:
        The merged list of configs that survived pruning for either shape.
    """
    m_dim = num_tokens * 2
    # (N, K) for each of the two GEMM shapes, in evaluation order.
    gemm_shapes = [
        (shard_intermediate_size, hidden_size),
        (hidden_size, shard_intermediate_size // 2),
    ]
    survivors_1, survivors_2 = [
        prune_rocm_configs(m_dim, n_dim, k_dim, search_space, is_fp16)
        for n_dim, k_dim in gemm_shapes
    ]
    return merge_unique_dicts(survivors_1, survivors_2)
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    """Drop tuning configs that the heuristics below rule out for (M, N, K).

    Args:
        M: GEMM M dimension.
        N: GEMM N dimension.
        K: GEMM K dimension.
        configs: Iterable of config dicts. Each is read for
            "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M",
            "num_warps", the optional "SPLIT_K" (treated as 1 when absent)
            and, when ``is_fp16`` is true, "matrix_instr_nonkdim".
        is_fp16: When true, operands are counted as 2 bytes per element in
            the LDS estimate and the ``matrix_instr_nonkdim`` checks are
            applied; otherwise elements count as 1 byte and those checks
            are skipped.

    Returns:
        A new list holding the surviving config dicts (the same objects,
        in their original order).
    """
    # Both GEMM operands use the same element width.
    elem_bytes = 2 if is_fp16 else 1

    # Upper bound on matrix_instr_nonkdim for this problem size. It is only
    # ever 16 or 32, so no smaller-instruction special case is needed.
    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = M >= 2048 and N >= 2048

    # Upper limit, in bytes, applied to the LDS estimate below.
    lds_limit = 65536

    pruned_configs = []
    for config in configs:
        block_m = config.get("BLOCK_SIZE_M")
        block_n = config.get("BLOCK_SIZE_N")
        block_k = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue

        # Some layouts do not work properly when the number of elements
        # per thread is less than 1.
        if block_m * block_n < 64:
            continue

        split_k = config.get("SPLIT_K", 1)
        group_m = config.get("GROUP_SIZE_M")

        if is_fp16:
            # The matrix instruction must fit inside the tile.
            if (matrix_instr_nonkdim > block_m
                    or matrix_instr_nonkdim > block_n):
                continue
            # When the instruction already spans all of M (or N), only a
            # tile of exactly that size is kept.
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != block_m:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != block_n:
                continue

        # Skip a BLOCK_SIZE that is too large compared to M/N,
        # unless BLOCK_SIZE is already small enough.
        if M * 2 < block_m and block_m != 16:
            continue
        if N * 2 < block_n and block_n != 16:
            continue

        # Skip large split_k when not necessary.
        if split_k != 1 and not need_split_k(M, N, K):
            continue

        # Skip split_k that leads to EVEN_K = false.
        if K % (split_k * block_k) != 0:
            continue

        # Skip large GROUP_M.
        if group_m * block_m > M and group_m != 1:
            continue

        # Out of shared memory resource.
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        lds_bytes = block_k * (block_m + block_n) * elem_bytes
        if lds_bytes > lds_limit:
            continue

        # Skip small block sizes and num_warps for large gemm.
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64.
        if large_gemm and (block_m < 64 or block_n < 64 or block_k < 64
                           or num_warps < 4):
            continue

        pruned_configs.append(config)

    return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    """Return whether a GEMM of this size qualifies for split-K.

    True only when at least one of ``SIZE_M`` / ``SIZE_N`` is below 64 and
    ``SIZE_K`` is greater than 1024.
    """
    has_small_output_dim = SIZE_M < 64 or SIZE_N < 64
    return has_small_output_dim and SIZE_K > 1024
def merge_unique_dicts(list1, list2):
    """Concatenate two lists of dicts, keeping the first of each equal entry.

    Entries are compared by equality and the order of first occurrence
    (``list1`` first, then ``list2``) is preserved. Neither input list is
    modified.
    """
    unique_entries = []
    for candidate in [*list1.copy(), *list2]:
        if candidate in unique_entries:
            continue
        unique_entries.append(candidate)
    return unique_entries
@
ray
.
remote
(
num_gpus
=
1
)
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
class
BenchmarkWorker
:
...
@@ -170,6 +318,10 @@ class BenchmarkWorker:
...
@@ -170,6 +318,10 @@ class BenchmarkWorker:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
self
.
seed
=
seed
self
.
seed
=
seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self
.
device_id
=
int
(
ray
.
get_gpu_ids
()[
0
])
def
benchmark
(
def
benchmark
(
self
,
self
,
...
@@ -217,6 +369,14 @@ class BenchmarkWorker:
...
@@ -217,6 +369,14 @@ class BenchmarkWorker:
)
->
Dict
[
str
,
int
]:
)
->
Dict
[
str
,
int
]:
best_config
=
None
best_config
=
None
best_time
=
float
(
"inf"
)
best_time
=
float
(
"inf"
)
if
current_platform
.
is_rocm
():
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
search_space
=
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
hidden_size
,
search_space
,
is_fp16
)
with
torch
.
cuda
.
device
(
self
.
device_id
):
for
config
in
tqdm
(
search_space
):
for
config
in
tqdm
(
search_space
):
try
:
try
:
kernel_time
=
benchmark_config
(
config
,
kernel_time
=
benchmark_config
(
config
,
...
@@ -228,7 +388,7 @@ class BenchmarkWorker:
...
@@ -228,7 +388,7 @@ class BenchmarkWorker:
dtype
,
dtype
,
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
num_iters
=
1
0
)
num_iters
=
2
0
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
# Some configurations may be invalid and fail to compile.
continue
continue
...
@@ -244,12 +404,27 @@ class BenchmarkWorker:
...
@@ -244,12 +404,27 @@ class BenchmarkWorker:
def
sort_config
(
config
:
BenchmarkConfig
)
->
BenchmarkConfig
:
def
sort_config
(
config
:
BenchmarkConfig
)
->
BenchmarkConfig
:
return
{
return
{
"BLOCK_SIZE_M"
:
config
[
"BLOCK_SIZE_M"
],
"BLOCK_SIZE_M"
:
"BLOCK_SIZE_N"
:
config
[
"BLOCK_SIZE_N"
],
config
[
"BLOCK_SIZE_M"
],
"BLOCK_SIZE_K"
:
config
[
"BLOCK_SIZE_K"
],
"BLOCK_SIZE_N"
:
"GROUP_SIZE_M"
:
config
[
"GROUP_SIZE_M"
],
config
[
"BLOCK_SIZE_N"
],
"num_warps"
:
config
[
"num_warps"
],
"BLOCK_SIZE_K"
:
"num_stages"
:
config
[
"num_stages"
],
config
[
"BLOCK_SIZE_K"
],
"GROUP_SIZE_M"
:
config
[
"GROUP_SIZE_M"
],
"num_warps"
:
config
[
"num_warps"
],
"num_stages"
:
config
[
"num_stages"
],
**
({
"waves_per_eu"
:
config
[
"waves_per_eu"
]
}
if
"waves_per_eu"
in
config
else
{}),
**
({
"matrix_instr_nonkdim"
:
config
[
"matrix_instr_nonkdim"
]
}
if
"matrix_instr_nonkdim"
in
config
else
{}),
**
({
"kpack"
:
config
[
"kpack"
]
}
if
"kpack"
in
config
else
{}),
}
}
...
@@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
...
@@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
dtype
=
config
.
torch_dtype
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
...
@@ -322,7 +497,8 @@ def main(args: argparse.Namespace):
...
@@ -322,7 +497,8 @@ def main(args: argparse.Namespace):
return
ray
.
get
(
outputs
)
return
ray
.
get
(
outputs
)
if
args
.
tune
:
if
args
.
tune
:
search_space
=
get_configs_compute_bound
()
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
search_space
=
get_configs_compute_bound
(
is_fp16
)
print
(
f
"Start tuning over
{
len
(
search_space
)
}
configurations..."
)
print
(
f
"Start tuning over
{
len
(
search_space
)
}
configurations..."
)
start
=
time
.
time
()
start
=
time
.
time
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment