Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
68886f7d
Commit
68886f7d
authored
Jun 14, 2022
by
raman jana
Browse files
merging with latest develop branch
parents
a9ee2960
1677cf70
Changes
328
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1258 additions
and
297 deletions
+1258
-297
script/profile_conv.sh
script/profile_conv.sh
+52
-52
script/test_reduce_no_index.sh
script/test_reduce_no_index.sh
+11
-0
script/test_reduce_with_index.sh
script/test_reduce_with_index.sh
+11
-0
test/CMakeLists.txt
test/CMakeLists.txt
+6
-1
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
+4
-4
test/block_to_ctile_map/CMakeLists.txt
test/block_to_ctile_map/CMakeLists.txt
+1
-0
test/block_to_ctile_map/test_block_to_ctile_map.cpp
test/block_to_ctile_map/test_block_to_ctile_map.cpp
+318
-0
test/client_app/CMakeLists.txt
test/client_app/CMakeLists.txt
+11
-0
test/client_app/client_app.cpp
test/client_app/client_app.cpp
+77
-0
test/client_app/client_app_impl.hpp
test/client_app/client_app_impl.hpp
+214
-0
test/conv2d_bwd_weight/conv2d_bwd_weight.cpp
test/conv2d_bwd_weight/conv2d_bwd_weight.cpp
+16
-16
test/convnd_bwd_data/convnd_bwd_data.cpp
test/convnd_bwd_data/convnd_bwd_data.cpp
+48
-48
test/gemm/CMakeLists.txt
test/gemm/CMakeLists.txt
+26
-12
test/gemm/gemm_dl_fp16.cpp
test/gemm/gemm_dl_fp16.cpp
+135
-132
test/gemm/gemm_dl_fp32.cpp
test/gemm/gemm_dl_fp32.cpp
+133
-0
test/gemm/gemm_dl_int8.cpp
test/gemm/gemm_dl_int8.cpp
+133
-0
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+46
-26
test/gemm/gemm_xdl_bf16.cpp
test/gemm/gemm_xdl_bf16.cpp
+0
-0
test/gemm/gemm_xdl_fp16.cpp
test/gemm/gemm_xdl_fp16.cpp
+8
-3
test/gemm/gemm_xdl_fp32.cpp
test/gemm/gemm_xdl_fp32.cpp
+8
-3
No files found.
script/profile_conv.sh
View file @
68886f7d
...
...
@@ -3,9 +3,9 @@
## GPU visibility
export
HIP_VISIBLE_DEVICES
=
0
make
-j
ckProfiler
#
make -j ckProfiler
DRIVER
=
".
/profiler
/ckProfiler"
DRIVER
=
".
./build/bin
/ckProfiler"
OP
=
$1
DATATYPE
=
$2
...
...
@@ -51,56 +51,56 @@ REPEAT=$9
# Resnet50 from Bing
#############
####### op_________________
___
datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C_
__
Y X Hi_
_
Wi__ Strides Dilations LeftPads RightPads
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0
#profiler/ckProfiler
conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1
#profiler/ckProfiler
conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0
####### op_________________
datatype in_layout wei_layout out_layout verify init log repeat
N__
K___ C_ Y X
Hi_ Wi__ Strides Dilations LeftPads RightPads
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 3 7 7 224 224 2 2 1 1 3 3 3 3
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 64 3 3 56 56 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 64 3 3 56 56 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
64 64 3 3 56 56 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 64 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 256 1 1 56 56 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 128 3 3 56 56 2 2 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 128 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 128 3 3 28 28 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 128 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 128 3 3 28 28 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 128 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
128 128 3 3 28 28 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 128 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 512 1 1 28 28 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 256 3 3 28 28 2 2 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
256 256 3 3 14 14 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
1024 256 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 1024 1 1 14 14 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 512 3 3 14 14 2 2 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
2048 512 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 2048 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 512 3 3 7 7 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
2048 512 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 2048 1 1 7 7 1 1 1 1 0 0 0 0
$DRIVER
conv_fwd_bias_relu
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
512 512 3 3 7 7 1 1 1 1 1 1 1 1
$DRIVER
conv_fwd_bias_relu_add
$DATATYPE
$IN_LAYOUT
$WEI_LAYOUT
$OUT_LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
$N
2048 512 1 1 7 7 1 1 1 1 0 0 0 0
# Resnet50
...
...
script/test_reduce_no_index.sh
View file @
68886f7d
...
...
@@ -15,6 +15,17 @@ bin/test_reduce_no_index -D 64,4,280,82 -R 1 0 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
2 0 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
3 0 2
## for float64
bin/test_reduce_no_index
-D
64,4,280,82
-R
0,1,2,3 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
0,1,2 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
0,1,3 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
0,2,3 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
1,2,3 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
0 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
1 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
2 6 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
3 6 2
## for float16
bin/test_reduce_no_index
-D
64,4,280,82
-R
0,1,2,3 1 2
bin/test_reduce_no_index
-D
64,4,280,82
-R
0,1,2 1 2
...
...
script/test_reduce_with_index.sh
View file @
68886f7d
...
...
@@ -15,6 +15,17 @@ bin/test_reduce_with_index -D 64,4,280,82 -R 1 0 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
2 0 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
3 0 2
## for float64
bin/test_reduce_with_index
-D
64,4,280,82
-R
0,1,2,3 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
0,1,2 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
0,1,3 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
0,2,3 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
1,2,3 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
0 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
1 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
2 6 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
3 6 2
## for float16
bin/test_reduce_with_index
-D
64,4,280,82
-R
0,1,2,3 1 2
bin/test_reduce_with_index
-D
64,4,280,82
-R
0,1,2 1 2
...
...
test/CMakeLists.txt
View file @
68886f7d
...
...
@@ -2,6 +2,7 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/
${
PROJECT_SOURCE_DIR
}
/include/ck
${
PROJECT_SOURCE_DIR
}
/include/ck/utility
${
PROJECT_SOURCE_DIR
}
/include/ck/host_utility
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_description
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor
${
PROJECT_SOURCE_DIR
}
/include/ck/problem_transform
...
...
@@ -22,7 +23,8 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/external/include/half
)
add_custom_target
(
check COMMAND
${
CMAKE_CTEST_COMMAND
}
--output-on-failure -C
${
CMAKE_CFG_INTDIR
}
)
include
(
googletest
)
add_custom_target
(
tests
)
...
...
@@ -61,3 +63,6 @@ add_subdirectory(grouped_gemm)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
reduce
)
add_subdirectory
(
conv2d_bwd_weight
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
block_to_ctile_map
)
# DONOT add client_app, that is tested via CI independently
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
View file @
68886f7d
...
...
@@ -22,7 +22,7 @@ int main()
Row
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
K
,
N
,
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
...
...
@@ -31,7 +31,7 @@ int main()
Row
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
K
,
K
,
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
...
...
@@ -40,7 +40,7 @@ int main()
Col
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
M
,
N
,
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
...
...
@@ -49,7 +49,7 @@ int main()
Col
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
M
,
K
,
N
,
BatchCount
);
if
(
pass
)
{
...
...
test/block_to_ctile_map/CMakeLists.txt
0 → 100644
View file @
68886f7d
add_gtest_executable
(
test_block_to_ctile_map test_block_to_ctile_map.cpp
)
\ No newline at end of file
test/block_to_ctile_map/test_block_to_ctile_map.cpp
0 → 100644
View file @
68886f7d
#include <ck/config.hpp>
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "gtest/gtest.h"
#include <iostream>
#include <vector>
using
namespace
ck
;
static
auto
I0
=
Number
<
0
>
{};
static
auto
I1
=
Number
<
1
>
{};
static
auto
I2
=
Number
<
2
>
{};
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
)
{
const
index_t
M
=
384
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
const
index_t
MBlock
=
M
/
MPerBlock
;
const
index_t
NBlock
=
N
/
NPerBlock
;
const
index_t
M01
=
4
;
const
index_t
N01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
,
N01
);
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
),
true
>
tile_map
(
c_grid_desc_m_n
,
M01
,
N01
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
true
);
EXPECT_TRUE
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
==
16
);
// clang-format off
std
::
vector
<
std
::
vector
<
int
>>
expected_m0idx_n0idx_valid
=
{
{
0
,
0
,
1
},
{
0
,
1
,
1
},
{
0
,
2
,
1
},
{
0
,
3
,
0
},
{
1
,
0
,
1
},
{
1
,
1
,
1
},
{
1
,
2
,
1
},
{
1
,
3
,
0
},
{
2
,
0
,
1
},
{
2
,
1
,
1
},
{
2
,
2
,
1
},
{
2
,
3
,
0
},
{
3
,
0
,
0
},
{
3
,
1
,
0
},
{
3
,
2
,
0
},
{
3
,
3
,
0
}
};
// clang-format on
for
(
index_t
i
=
0
;
i
<
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
i
++
)
{
auto
m0n0_idx
=
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
i
));
std
::
cout
<<
"block_1d_id = "
<<
i
<<
", m0, n0 = "
<<
m0n0_idx
[
I0
]
<<
", "
<<
m0n0_idx
[
I1
];
std
::
cout
<<
", valid = "
<<
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))
<<
std
::
endl
;
bool
equal
=
expected_m0idx_n0idx_valid
[
i
]
==
std
::
vector
<
int
>
{
m0n0_idx
[
I0
],
m0n0_idx
[
I1
],
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))};
EXPECT_TRUE
(
equal
);
}
}
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0
)
{
const
index_t
M
=
384
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
const
index_t
M01
=
4
;
const
index_t
N01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
,
N01
);
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
),
false
>
tile_map
(
c_grid_desc_m_n
,
M01
,
N01
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
false
);
}
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1
)
{
const
index_t
M
=
384
;
const
index_t
N
=
512
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
const
index_t
MBlock
=
M
/
MPerBlock
;
const
index_t
NBlock
=
N
/
NPerBlock
;
const
index_t
M01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
);
BlockToCTileMap_M00_N0_M01
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
),
true
>
tile_map
(
c_grid_desc_m_n
,
M01
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
true
);
EXPECT_TRUE
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
==
16
);
// clang-format off
std
::
vector
<
std
::
vector
<
int
>>
expected_m0idx_n0idx_valid
=
{
{
0
,
0
,
1
},
{
1
,
0
,
1
},
{
2
,
0
,
1
},
{
3
,
0
,
0
},
{
0
,
1
,
1
},
{
1
,
1
,
1
},
{
2
,
1
,
1
},
{
3
,
1
,
0
},
{
0
,
2
,
1
},
{
1
,
2
,
1
},
{
2
,
2
,
1
},
{
3
,
2
,
0
},
{
0
,
3
,
1
},
{
1
,
3
,
1
},
{
2
,
3
,
1
},
{
3
,
3
,
0
}
};
// clang-format on
for
(
index_t
i
=
0
;
i
<
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
i
++
)
{
auto
m0n0_idx
=
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
i
));
std
::
cout
<<
"block_1d_id = "
<<
i
<<
", m0, n0 = "
<<
m0n0_idx
[
I0
]
<<
", "
<<
m0n0_idx
[
I1
];
std
::
cout
<<
", valid = "
<<
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))
<<
std
::
endl
;
bool
equal
=
expected_m0idx_n0idx_valid
[
i
]
==
std
::
vector
<
int
>
{
m0n0_idx
[
I0
],
m0n0_idx
[
I1
],
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))};
EXPECT_TRUE
(
equal
);
}
}
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0
)
{
const
index_t
M
=
512
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
// clang-format off
std
::
vector
<
std
::
tuple
<
int
,
int
,
bool
>>
expected_m0_gridsize_validity
=
{
{
5
,
15
,
false
},
{
4
,
12
,
true
},
{
3
,
18
,
false
},
{
2
,
12
,
true
},
{
1
,
12
,
true
}
};
// clang-format on
for
(
auto
e
:
expected_m0_gridsize_validity
)
{
const
index_t
M01
=
std
::
get
<
0
>
(
e
);
printf
(
"(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
);
BlockToCTileMap_M00_N0_M01
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
),
false
>
tile_map
(
c_grid_desc_m_n
,
M01
);
EXPECT_EQ
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
),
std
::
get
<
1
>
(
e
));
EXPECT_EQ
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
),
std
::
get
<
2
>
(
e
));
}
}
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N0_M01Adapt
)
{
const
index_t
M
=
768
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
const
index_t
MBlock
=
M
/
MPerBlock
;
const
index_t
NBlock
=
N
/
NPerBlock
;
constexpr
index_t
M01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
);
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
)
>
tile_map
(
c_grid_desc_m_n
,
M01
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
true
);
EXPECT_TRUE
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
==
18
);
// clang-format off
std
::
vector
<
std
::
vector
<
int
>>
expected_m0idx_n0idx_valid
=
{
{
0
,
0
,
1
},
{
1
,
0
,
1
},
{
2
,
0
,
1
},
{
3
,
0
,
1
},
{
0
,
1
,
1
},
{
1
,
1
,
1
},
{
2
,
1
,
1
},
{
3
,
1
,
1
},
{
0
,
2
,
1
},
{
1
,
2
,
1
},
{
2
,
2
,
1
},
{
3
,
2
,
1
},
{
4
,
0
,
1
},
{
5
,
0
,
1
},
{
4
,
1
,
1
},
{
5
,
1
,
1
},
{
4
,
2
,
1
},
{
5
,
2
,
1
},
};
// clang-format on
for
(
index_t
i
=
0
;
i
<
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
i
++
)
{
auto
m0n0_idx
=
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
i
));
std
::
cout
<<
"block_1d_id = "
<<
i
<<
", m0, n0 = "
<<
m0n0_idx
[
I0
]
<<
", "
<<
m0n0_idx
[
I1
];
std
::
cout
<<
", valid = "
<<
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))
<<
std
::
endl
;
bool
equal
=
expected_m0idx_n0idx_valid
[
i
]
==
std
::
vector
<
int
>
{
m0n0_idx
[
I0
],
m0n0_idx
[
I1
],
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))};
EXPECT_TRUE
(
equal
);
}
}
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_KSplit_M00_N0_M01Adapt
)
{
const
index_t
M
=
768
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
const
index_t
MBlock
=
M
/
MPerBlock
;
const
index_t
NBlock
=
N
/
NPerBlock
;
constexpr
index_t
M01
=
4
;
const
index_t
KSplit
=
3
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
);
BlockToCTileMap_KSplit_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
)
>
tile_map
(
c_grid_desc_m_n
,
M01
,
KSplit
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
true
);
EXPECT_TRUE
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
==
18
*
KSplit
);
std
::
vector
<
std
::
vector
<
int
>>
expected_ksplitidx_m0idx_n0idx_valid
=
{
{
0
,
0
,
0
,
1
},
{
0
,
1
,
0
,
1
},
{
0
,
2
,
0
,
1
},
{
0
,
3
,
0
,
1
},
{
0
,
0
,
1
,
1
},
{
0
,
1
,
1
,
1
},
{
0
,
2
,
1
,
1
},
{
0
,
3
,
1
,
1
},
{
0
,
0
,
2
,
1
},
{
0
,
1
,
2
,
1
},
{
0
,
2
,
2
,
1
},
{
0
,
3
,
2
,
1
},
{
0
,
4
,
0
,
1
},
{
0
,
5
,
0
,
1
},
{
0
,
4
,
1
,
1
},
{
0
,
5
,
1
,
1
},
{
0
,
4
,
2
,
1
},
{
0
,
5
,
2
,
1
},
{
1
,
0
,
0
,
1
},
{
1
,
1
,
0
,
1
},
{
1
,
2
,
0
,
1
},
{
1
,
3
,
0
,
1
},
{
1
,
0
,
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
2
,
1
,
1
},
{
1
,
3
,
1
,
1
},
{
1
,
0
,
2
,
1
},
{
1
,
1
,
2
,
1
},
{
1
,
2
,
2
,
1
},
{
1
,
3
,
2
,
1
},
{
1
,
4
,
0
,
1
},
{
1
,
5
,
0
,
1
},
{
1
,
4
,
1
,
1
},
{
1
,
5
,
1
,
1
},
{
1
,
4
,
2
,
1
},
{
1
,
5
,
2
,
1
},
{
2
,
0
,
0
,
1
},
{
2
,
1
,
0
,
1
},
{
2
,
2
,
0
,
1
},
{
2
,
3
,
0
,
1
},
{
2
,
0
,
1
,
1
},
{
2
,
1
,
1
,
1
},
{
2
,
2
,
1
,
1
},
{
2
,
3
,
1
,
1
},
{
2
,
0
,
2
,
1
},
{
2
,
1
,
2
,
1
},
{
2
,
2
,
2
,
1
},
{
2
,
3
,
2
,
1
},
{
2
,
4
,
0
,
1
},
{
2
,
5
,
0
,
1
},
{
2
,
4
,
1
,
1
},
{
2
,
5
,
1
,
1
},
{
2
,
4
,
2
,
1
},
{
2
,
5
,
2
,
1
},
};
for
(
index_t
i
=
0
;
i
<
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
i
++
)
{
auto
ksplitm0n0_idx
=
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
i
));
std
::
cout
<<
"block_1d_id = "
<<
i
<<
", ksplit, m0, n0 = "
<<
ksplitm0n0_idx
[
I0
]
<<
", "
<<
ksplitm0n0_idx
[
I1
]
<<
", "
<<
ksplitm0n0_idx
[
I2
];
std
::
cout
<<
", valid = "
<<
tile_map
.
ValidCTileIndex
(
ksplitm0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))
<<
std
::
endl
;
bool
equal
=
expected_ksplitidx_m0idx_n0idx_valid
[
i
]
==
std
::
vector
<
int
>
{
ksplitm0n0_idx
[
I0
],
ksplitm0n0_idx
[
I1
],
ksplitm0n0_idx
[
I2
],
tile_map
.
ValidCTileIndex
(
ksplitm0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))};
EXPECT_TRUE
(
equal
);
}
}
test/client_app/CMakeLists.txt
0 → 100644
View file @
68886f7d
cmake_minimum_required
(
VERSION 3.15
)
project
(
ck_app
)
add_compile_options
(
-std=c++14
)
find_package
(
composable_kernel 1.0.0 COMPONENTS device_operations host_tensor
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
add_executable
(
test_client_app client_app.cpp
)
target_link_libraries
(
test_client_app PRIVATE composable_kernel::device_operations composable_kernel::host_tensor hip::host
)
test/client_app/client_app.cpp
0 → 100644
View file @
68886f7d
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <vector>
#include "client_app_impl.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
25
)
{
printf
(
"arg1: tensor operation (conv_fwd: ForwardConvolution)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: input tensor layout (0: NCHW; 1: NHWC)
\n
"
);
printf
(
"arg4: weight tensor layout (0: KCYX; 1: KYXC)
\n
"
);
printf
(
"arg5: output tensor layout (0: NKHW; 1: NHWK)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
1
);
}
const
ConvDataType
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
18
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
19
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
20
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
21
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
22
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
23
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
24
]);
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
ck
::
app
::
profile_conv_fwd_impl
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
data_type
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
return
1
;
}
test/client_app/client_app_impl.hpp
0 → 100644
View file @
68886f7d
#pragma once
#include "host_interface.hpp"
enum
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
void
check_hip_error
(
void
)
{
hipError_t
err
=
hipGetLastError
();
if
(
err
!=
hipSuccess
)
{
std
::
cerr
<<
"Error: "
<<
hipGetErrorString
(
err
)
<<
std
::
endl
;
exit
(
err
);
}
}
std
::
string
getDeviceName
(
int
device
)
{
struct
hipDeviceProp_t
prop
;
hipGetDeviceProperties
(
&
prop
,
device
);
check_hip_error
();
return
std
::
string
(
prop
.
name
);
}
int
getDriver
(
void
)
{
int
driver
;
hipDriverGetVersion
(
&
driver
);
check_hip_error
();
return
driver
;
}
namespace
ck
{
namespace
app
{
// Minimal RAII wrapper around a raw HIP device allocation.
// Owns mpDeviceBuf for its whole lifetime and frees it in the destructor.
struct DeviceMem
{
    DeviceMem() = delete;
    // Allocate mem_size bytes of device memory.
    DeviceMem(std::size_t mem_size);
    // Fix: this type owns a raw device pointer and frees it in ~DeviceMem();
    // copying would double-free, so copy/move are explicitly deleted
    // (Rule of Five).
    DeviceMem(const DeviceMem&) = delete;
    DeviceMem& operator=(const DeviceMem&) = delete;
    // Raw device pointer for passing into kernels/arguments.
    void* GetDeviceBuffer();
    // Copy mMemSize bytes from host pointer p into the device buffer.
    void ToDevice(const void* p);
    // Copy mMemSize bytes from the device buffer back to host pointer p.
    void FromDevice(void* p);
    ~DeviceMem();

    void* mpDeviceBuf;
    std::size_t mMemSize;
};

DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
    // NOTE(review): hipGetErrorString() only formats the status, so
    // allocation failures are effectively swallowed here — consider
    // checking the result and failing fast. Kept as-is for parity with
    // the library's DeviceMem.
    hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}

void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }

void DeviceMem::ToDevice(const void* p)
{
    hipGetErrorString(
        hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p)
{
    hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
// Profile all registered 2D forward-convolution device instances matching
// data_type and print per-instance and best-overall perf numbers.
//
// NOTE(review): the device buffers are allocated but never initialized from
// host data — this app only measures kernel execution time, not correctness;
// do_verification / init_method / do_log are accepted for interface parity
// with the profiler and are currently unused.
void profile_conv_fwd_impl(int do_verification,
                           int init_method,
                           bool do_log,
                           bool time_kernel,
                           ConvDataType data_type,
                           ck::index_t N,
                           ck::index_t K,
                           ck::index_t C,
                           std::vector<ck::index_t> input_spatial_lengths,
                           std::vector<ck::index_t> filter_spatial_lengths,
                           std::vector<ck::index_t> output_spatial_lengths,
                           std::vector<ck::index_t> conv_filter_strides,
                           std::vector<ck::index_t> conv_filter_dilations,
                           std::vector<ck::index_t> input_left_pads,
                           std::vector<ck::index_t> input_right_pads)
{
    const ck::index_t Y  = filter_spatial_lengths[0];
    const ck::index_t X  = filter_spatial_lengths[1];
    const ck::index_t Hi = input_spatial_lengths[0];
    const ck::index_t Wi = input_spatial_lengths[1];
    const ck::index_t Ho = output_spatial_lengths[0];
    const ck::index_t Wo = output_spatial_lengths[1];

    const auto in_sz  = N * C * Hi * Wi;
    const auto wei_sz = K * C * Y * X;
    const auto out_sz = N * K * Ho * Wo;

    // NOTE(review): element sizes are always computed with float, even for
    // fp16/bf16/int8 runs; this over-allocates but is otherwise harmless.
    using WeiDataType = float;
    using InDataType  = float;
    using OutDataType = float;

    app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz);
    app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz);
    app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz);
    // data is already on device!
    // add device Conv instances
    std::vector<DeviceConvFwdPtr_t> conv_ptrs;
    if(data_type == F16_F16_F16)
    {
        add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
    }
    else if(data_type == BF16_BF16_BF16)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs);
    else if(data_type == F32_F32_F32)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs);
    else if(data_type == INT8_INT8_INT8)
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs);
    else
        throw std::runtime_error("wrong! Invalid data type");

    if(conv_ptrs.empty())
    {
        throw std::runtime_error("wrong! no device Conv instance found");
    }

    std::string best_conv_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    int deviceIndex = 0;
    hipSetDevice(deviceIndex);
    check_hip_error();
    StreamConfig stream_config{nullptr, time_kernel};
    hipStreamCreate(&stream_config.stream_id_);
    check_hip_error();

    // profile device Conv instances
    for(auto& conv_ptr : conv_ptrs)
    {
        auto argument_ptr =
            conv_ptr.MakeArgumentPointer(static_cast<void*>(in_device_buf.GetDeviceBuffer()),
                                         static_cast<void*>(wei_device_buf.GetDeviceBuffer()),
                                         static_cast<void*>(out_device_buf.GetDeviceBuffer()),
                                         N,
                                         K,
                                         C,
                                         input_spatial_lengths,
                                         filter_spatial_lengths,
                                         output_spatial_lengths,
                                         conv_filter_strides,
                                         conv_filter_dilations,
                                         input_left_pads,
                                         input_right_pads);

        auto invoker_ptr = conv_ptr.MakeInvokerPointer();

        if(conv_ptr.IsSupportedArgument(argument_ptr.get()))
        {
            std::string conv_name = conv_ptr.GetTypeString();

            float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);

            // 2 * N * K * Ho * Wo * C * Y * X multiply-accumulates.
            std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;

            std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
                                    sizeof(WeiDataType) * (K * C * Y * X) +
                                    sizeof(OutDataType) * (N * K * Ho * Wo);

            // ave_time is in ms, so flop / 1e9 / ms == TFlop/s.
            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << conv_name << std::endl;

            // Track the fastest supported instance by TFlops.
            if(tflops > best_tflops)
            {
                best_conv_name  = conv_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;

    // Fix: the stream created above was leaked; release it before returning.
    hipStreamDestroy(stream_config.stream_id_);
    check_hip_error();
}
}
// namespace app
}
// namespace ck
test/conv2d_bwd_weight/conv2d_bwd_weight.cpp
View file @
68886f7d
...
...
@@ -28,10 +28,10 @@ int test_self()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -52,10 +52,10 @@ int test_self()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -72,8 +72,8 @@ int test_self()
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
data_type
=
0
;
int
init_method
=
0
;
int
data_type
=
1
;
int
init_method
=
1
;
// Conv shape
ck
::
index_t
N
=
128
;
...
...
@@ -155,10 +155,10 @@ int main(int argc, char* argv[])
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
true
,
// do_verification
init_method
,
0
,
1
,
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -180,10 +180,10 @@ int main(int argc, char* argv[])
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
true
,
// do_verification
init_method
,
0
,
1
,
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
test/convnd_bwd_data/convnd_bwd_data.cpp
View file @
68886f7d
...
...
@@ -27,10 +27,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -50,10 +50,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -73,10 +73,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -96,10 +96,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -128,10 +128,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -151,10 +151,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -174,10 +174,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -197,10 +197,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -232,10 +232,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -255,10 +255,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -278,10 +278,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -301,10 +301,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
test/gemm/CMakeLists.txt
View file @
68886f7d
add_test_executable
(
test_gemm_fp32 gemm_fp32.cpp
)
target_link_libraries
(
test_gemm_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_fp32 PRIVATE device_gemm_instance
)
# GEMM XDL
add_test_executable
(
test_gemm_xdl_fp32 gemm_xdl_fp32.cpp
)
target_link_libraries
(
test_gemm_xdl_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_xdl_fp32 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_fp16 gemm_fp16.cpp
)
target_link_libraries
(
test_gemm_fp16 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_fp16 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_
xdl_
fp16 gemm_
xdl_
fp16.cpp
)
target_link_libraries
(
test_gemm_
xdl_
fp16 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_
xdl_
fp16 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_bf16 gemm_bf16.cpp
)
target_link_libraries
(
test_gemm_bf16 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_bf16 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_
xdl_
bf16 gemm_
xdl_
bf16.cpp
)
target_link_libraries
(
test_gemm_
xdl_
bf16 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_
xdl_
bf16 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_int8 gemm_int8.cpp
)
target_link_libraries
(
test_gemm_int8 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_int8 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_xdl_int8 gemm_xdl_int8.cpp
)
target_link_libraries
(
test_gemm_xdl_int8 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_xdl_int8 PRIVATE device_gemm_instance
)
# GEMM DL
add_test_executable
(
test_gemm_dl_fp32 gemm_dl_fp32.cpp
)
target_link_libraries
(
test_gemm_dl_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_dl_fp32 PRIVATE device_gemm_instance
)
add_test_executable
(
test_gemm_dl_fp16 gemm_dl_fp16.cpp
)
target_link_libraries
(
test_gemm_dl_fp16 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_dl_fp16 PRIVATE device_gemm_instance
)
add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp)
target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor)
# Fix: command was spelled "TArget_link_libraries". CMake command names are
# case-insensitive so it happened to work, but the file consistently uses
# lowercase.
target_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance)
test/gemm/gemm_
int8
.cpp
→
test/gemm/gemm_
dl_fp16
.cpp
View file @
68886f7d
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
bool
res
=
true
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_dl_fp32.cpp
0 → 100644
View file @
68886f7d
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
void
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
int
main
()
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_dl_int8.cpp
0 → 100644
View file @
68886f7d
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "../gemm/gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_dl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
void
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_util.hpp
View file @
68886f7d
...
...
@@ -60,7 +60,7 @@ template <typename DeviceGemmPtr_,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
RunDeviceGEMM
(
DeviceGemmPtr_
&
gemmPtr
,
bool
RunDeviceGEMM
(
DeviceGemmPtr_
&
gemmPtr
,
const
ck
::
gemm_util
::
GemmParams
&
params
,
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
BDataType
>&
B
,
...
...
@@ -73,9 +73,6 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
auto
invoker_ptr
=
gemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
gemmPtr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
...
...
@@ -91,21 +88,30 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
b_element_op
,
c_element_op
);
if
(
!
gemmPtr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
gemmPtr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
invoker_ptr
->
Run
(
argument_ptr
.
get
());
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
return
true
;
}
else
{
std
::
cout
<<
"device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
<<
std
::
endl
;
invoker_ptr
->
Run
(
argument_ptr
.
get
())
;
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
return
false
;
}
}
template
<
typename
DeviceGemmPtr_
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
...
...
@@ -181,6 +187,7 @@ struct TestGemm
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
...
...
@@ -188,28 +195,40 @@ struct TestGemm
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
// Act
ck
::
gemm_util
::
RunDeviceGEMM
(
bool
is_supported
=
ck
::
gemm_util
::
RunDeviceGEMM
(
gemmPtr
,
params
,
a
,
b
,
c_device
,
a_element_op
,
b_element_op
,
c_element_op
);
// Assert
bool
res
=
false
;
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
if
(
is_supported
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
// Assert
bool
res
=
false
;
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
double
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
return
res
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
else
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
true
;
}
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
return
res
;
}
};
...
...
@@ -299,6 +318,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
float
,
AElementwiseOperation
,
...
...
test/gemm/gemm_bf16.cpp
→
test/gemm/gemm_
xdl_
bf16.cpp
View file @
68886f7d
File moved
test/gemm/gemm_fp16.cpp
→
test/gemm/gemm_
xdl_
fp16.cpp
View file @
68886f7d
...
...
@@ -52,9 +52,10 @@ void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -74,6 +75,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -96,6 +98,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -118,6 +121,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -142,6 +146,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/gemm/gemm_fp32.cpp
→
test/gemm/gemm_
xdl_
fp32.cpp
View file @
68886f7d
...
...
@@ -53,9 +53,10 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<De
int
main
()
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -75,6 +76,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -97,6 +99,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -119,6 +122,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -141,6 +145,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
Prev
1
…
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment