Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb5b6403
Unverified
Commit
bb5b6403
authored
Mar 03, 2025
by
Divakar Verma
Committed by
GitHub
Mar 04, 2025
Browse files
[core] moe fp8 block quant tuning support (#14068)
Signed-off-by:
Divakar Verma
<
divakar.verma@amd.com
>
parent
c060b714
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
57 deletions
+129
-57
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+67
-31
vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
...Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
+62
-26
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
bb5b6403
...
@@ -40,6 +40,7 @@ def benchmark_config(
...
@@ -40,6 +40,7 @@ def benchmark_config(
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
num_iters
:
int
=
100
,
block_quant_shape
:
List
[
int
]
=
None
,
)
->
float
:
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
...
@@ -81,8 +82,24 @@ def benchmark_config(
...
@@ -81,8 +82,24 @@ def benchmark_config(
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
if
block_quant_shape
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
E
=
num_experts
N
=
shard_intermediate_size
//
2
K
=
hidden_size
factor_for_scale
=
1e-2
n_tiles_w1
=
(
2
*
N
+
block_n
-
1
)
//
block_n
n_tiles_w2
=
(
K
+
block_n
-
1
)
//
block_n
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
w1_scale
=
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
*
factor_for_scale
w2_scale
=
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
*
factor_for_scale
else
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
...
@@ -111,6 +128,7 @@ def benchmark_config(
...
@@ -111,6 +128,7 @@ def benchmark_config(
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_quant_shape
,
)
)
# JIT compilation & warmup
# JIT compilation & warmup
...
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
...
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
return
param_ranges
return
param_ranges
def
get_configs_compute_bound
(
use_fp16
)
->
list
[
dict
[
str
,
int
]]:
def
get_configs_compute_bound
(
use_fp16
,
block_quant_shape
)
->
list
[
dict
[
str
,
int
]]:
configs
:
list
[
BenchmarkConfig
]
=
[]
configs
:
list
[
BenchmarkConfig
]
=
[]
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
...
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
for
config_values
in
product
(
*
values
):
for
config_values
in
product
(
*
values
):
config
=
dict
(
zip
(
keys
,
config_values
))
config
=
dict
(
zip
(
keys
,
config_values
))
configs
.
append
(
config
)
configs
.
append
(
config
)
# Remove configs that are not compatible with fp8 block quantization
# BLOCK_SIZE_K must be a multiple of block_k
# BLOCK_SIZE_N must be a multiple of block_n
if
block_quant_shape
is
not
None
and
not
use_fp16
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
for
config
in
configs
[:]:
if
config
[
"BLOCK_SIZE_K"
]
%
block_k
!=
0
or
config
[
"BLOCK_SIZE_N"
]
%
block_n
!=
0
:
configs
.
remove
(
config
)
return
configs
return
configs
def
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
hidden_size
,
def
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
hidden_size
,
search_space
,
is_fp16
):
search_space
,
is_fp16
,
topk
):
N1
,
K1
=
shard_intermediate_size
,
hidden_size
N1
,
K1
=
shard_intermediate_size
,
hidden_size
N2
,
K2
=
hidden_size
,
shard_intermediate_size
//
2
N2
,
K2
=
hidden_size
,
shard_intermediate_size
//
2
pruned_space_1
=
prune_rocm_configs
(
num_tokens
*
2
,
N1
,
K1
,
search_space
,
pruned_space_1
=
prune_rocm_configs
(
num_tokens
*
topk
,
N1
,
K1
,
is_fp16
)
search_space
,
is_fp16
)
pruned_space_2
=
prune_rocm_configs
(
num_tokens
*
2
,
N2
,
K2
,
search_space
,
pruned_space_2
=
prune_rocm_configs
(
num_tokens
*
topk
,
N2
,
K2
,
is_fp16
)
search_space
,
is_fp16
)
search_space
=
merge_unique_dicts
(
pruned_space_1
,
pruned_space_2
)
search_space
=
merge_unique_dicts
(
pruned_space_1
,
pruned_space_2
)
return
search_space
return
search_space
...
@@ -372,6 +401,7 @@ class BenchmarkWorker:
...
@@ -372,6 +401,7 @@ class BenchmarkWorker:
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
search_space
:
list
[
dict
[
str
,
int
]],
search_space
:
list
[
dict
[
str
,
int
]],
block_quant_shape
:
list
[
int
],
)
->
dict
[
str
,
int
]:
)
->
dict
[
str
,
int
]:
best_config
=
None
best_config
=
None
best_time
=
float
(
"inf"
)
best_time
=
float
(
"inf"
)
...
@@ -380,12 +410,13 @@ class BenchmarkWorker:
...
@@ -380,12 +410,13 @@ class BenchmarkWorker:
search_space
=
prune_rocm_search_space
(
num_tokens
,
search_space
=
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
shard_intermediate_size
,
hidden_size
,
search_space
,
hidden_size
,
search_space
,
is_fp16
)
is_fp16
,
topk
)
with
torch
.
cuda
.
device
(
self
.
device_id
):
with
torch
.
cuda
.
device
(
self
.
device_id
):
for
config
in
tqdm
(
search_space
):
for
config
in
tqdm
(
search_space
):
try
:
try
:
kernel_time
=
benchmark_config
(
config
,
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_tokens
,
num_experts
,
num_experts
,
shard_intermediate_size
,
shard_intermediate_size
,
...
@@ -394,7 +425,8 @@ class BenchmarkWorker:
...
@@ -394,7 +425,8 @@ class BenchmarkWorker:
dtype
,
dtype
,
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
num_iters
=
20
)
num_iters
=
20
,
block_quant_shape
=
block_quant_shape
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
# Some configurations may be invalid and fail to compile.
continue
continue
...
@@ -436,8 +468,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
...
@@ -436,8 +468,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
def
save_configs
(
configs
:
dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
def
save_configs
(
configs
:
dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
block_quant_shape
:
List
[
int
]
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype
,
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
use_fp8_w8a8
=
use_fp8_w8a8
)
...
@@ -445,7 +477,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
...
@@ -445,7 +477,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
)
dtype_str
,
block_quant_shape
)
print
(
f
"Writing best config to
{
filename
}
..."
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
with
open
(
filename
,
"w"
)
as
f
:
...
@@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
...
@@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
print
(
args
)
block_quant_shape
=
None
config
=
AutoConfig
.
from_pretrained
(
config
=
AutoConfig
.
from_pretrained
(
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
...
@@ -474,6 +506,7 @@ def main(args: argparse.Namespace):
...
@@ -474,6 +506,7 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
block_quant_shape
=
config
.
quantization_config
[
'weight_block_size'
]
else
:
else
:
# Default: Mixtral.
# Default: Mixtral.
E
=
config
.
num_local_experts
E
=
config
.
num_local_experts
...
@@ -511,26 +544,29 @@ def main(args: argparse.Namespace):
...
@@ -511,26 +544,29 @@ def main(args: argparse.Namespace):
if
args
.
tune
:
if
args
.
tune
:
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
search_space
=
get_configs_compute_bound
(
is_fp16
)
search_space
=
get_configs_compute_bound
(
is_fp16
,
block_quant_shape
)
print
(
f
"Start tuning over
{
len
(
search_space
)
}
configurations..."
)
print
(
f
"Start tuning over
{
len
(
search_space
)
}
configurations..."
)
start
=
time
.
time
()
start
=
time
.
time
()
configs
=
_distribute
(
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
"tune"
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
)
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
block_quant_shape
)
for
batch_size
in
batch_sizes
])
for
batch_size
in
batch_sizes
])
best_configs
=
{
best_configs
=
{
M
:
sort_config
(
config
)
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
}
}
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
)
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
else
:
else
:
outputs
=
_distribute
(
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
"benchmark"
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
)
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
)
for
batch_size
in
batch_sizes
])
for
batch_size
in
batch_sizes
])
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
...
...
vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json
View file @
bb5b6403
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
1
6
,
"BLOCK_SIZE_N"
:
1
28
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
...
@@ -31,15 +31,15 @@
...
@@ -31,15 +31,15 @@
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
...
@@ -49,13 +49,13 @@
...
@@ -49,13 +49,13 @@
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_warps"
:
2
,
...
@@ -64,7 +64,7 @@
...
@@ -64,7 +64,7 @@
},
},
"48"
:
{
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_warps"
:
2
,
...
@@ -73,7 +73,7 @@
...
@@ -73,7 +73,7 @@
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
2
,
...
@@ -82,46 +82,82 @@
...
@@ -82,46 +82,82 @@
},
},
"96"
:
{
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
"waves_per_eu"
:
0
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment