Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6e3cf8b0
Commit
6e3cf8b0
authored
May 24, 2022
by
Jing Zhang
Browse files
merge develop
parents
4ad62d7f
ba58a93f
Changes
177
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
531 additions
and
123 deletions
+531
-123
profiler/src/profile_gemm_bias_2d.cpp
profiler/src/profile_gemm_bias_2d.cpp
+11
-11
profiler/src/profile_gemm_bias_relu.cpp
profiler/src/profile_gemm_bias_relu.cpp
+7
-7
profiler/src/profile_gemm_bias_relu_add.cpp
profiler/src/profile_gemm_bias_relu_add.cpp
+7
-7
profiler/src/profile_gemm_reduce.cpp
profiler/src/profile_gemm_reduce.cpp
+7
-7
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+7
-7
profiler/src/profile_reduce.cpp
profiler/src/profile_reduce.cpp
+10
-10
test/CMakeLists.txt
test/CMakeLists.txt
+5
-1
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
+4
-4
test/block_to_ctile_map/CMakeLists.txt
test/block_to_ctile_map/CMakeLists.txt
+1
-0
test/block_to_ctile_map/test_block_to_ctile_map.cpp
test/block_to_ctile_map/test_block_to_ctile_map.cpp
+100
-0
test/client_app/CMakeLists.txt
test/client_app/CMakeLists.txt
+11
-0
test/client_app/client_app.cpp
test/client_app/client_app.cpp
+77
-0
test/client_app/client_app_impl.hpp
test/client_app/client_app_impl.hpp
+214
-0
test/conv2d_bwd_weight/conv2d_bwd_weight.cpp
test/conv2d_bwd_weight/conv2d_bwd_weight.cpp
+16
-16
test/convnd_bwd_data/convnd_bwd_data.cpp
test/convnd_bwd_data/convnd_bwd_data.cpp
+48
-48
test/gemm_reduce/gemm_reduce_fp16.cpp
test/gemm_reduce/gemm_reduce_fp16.cpp
+4
-4
test/gemm_split_k/gemm_split_k.cpp
test/gemm_split_k/gemm_split_k.cpp
+2
-1
No files found.
profiler/src/profile_gemm_bias_2d.cpp
View file @
6e3cf8b0
...
...
@@ -36,8 +36,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: alpha
\n
"
);
printf
(
"arg15: beta
\n
"
);
...
...
@@ -50,7 +50,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
...
@@ -76,7 +76,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -99,7 +99,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -122,7 +122,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -145,7 +145,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -168,7 +168,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -191,7 +191,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -214,7 +214,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -237,7 +237,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
profiler/src/profile_gemm_bias_relu.cpp
View file @
6e3cf8b0
...
...
@@ -36,8 +36,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: split k into mulitiple batch
\n
"
);
exit
(
1
);
...
...
@@ -48,7 +48,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
...
@@ -69,7 +69,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -88,7 +88,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -107,7 +107,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -126,7 +126,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
profiler/src/profile_gemm_bias_relu_add.cpp
View file @
6e3cf8b0
...
...
@@ -36,8 +36,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1
\n
"
);
printf
(
"arg15: split k into mulitiple batch
\n
"
);
exit
(
1
);
...
...
@@ -48,7 +48,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
...
@@ -70,7 +70,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -90,7 +90,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -110,7 +110,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -130,7 +130,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
profiler/src/profile_gemm_reduce.cpp
View file @
6e3cf8b0
...
...
@@ -32,8 +32,8 @@ int profile_gemm_reduce(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 13: M, N, K, StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg14: split k into mulitiple batch
\n
"
);
exit
(
1
);
...
...
@@ -44,7 +44,7 @@ int profile_gemm_reduce(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
...
...
@@ -66,7 +66,7 @@ int profile_gemm_reduce(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -87,7 +87,7 @@ int profile_gemm_reduce(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -108,7 +108,7 @@ int profile_gemm_reduce(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
@@ -129,7 +129,7 @@ int profile_gemm_reduce(int argc, char* argv[])
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
M
,
N
,
K
,
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
6e3cf8b0
...
...
@@ -54,8 +54,8 @@ int profile_grouped_gemm(int argc, char* argv[])
printf
(
" 3: A[k, m] * B[n, k] = C[m, n])
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg
8
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg
6
: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)
\n
"
);
exit
(
1
);
...
...
@@ -66,7 +66,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
const
int
init_method
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
7
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
auto
Ms
=
argToIntArray
(
argv
[
8
]);
const
auto
Ns
=
argToIntArray
(
argv
[
9
]);
...
...
@@ -86,7 +86,7 @@ int profile_grouped_gemm(int argc, char* argv[])
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
Ms
,
Ns
,
Ks
,
...
...
@@ -104,7 +104,7 @@ int profile_grouped_gemm(int argc, char* argv[])
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
Ms
,
Ns
,
Ks
,
...
...
@@ -122,7 +122,7 @@ int profile_grouped_gemm(int argc, char* argv[])
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
Ms
,
Ns
,
Ks
,
...
...
@@ -140,7 +140,7 @@ int profile_grouped_gemm(int argc, char* argv[])
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
time_kernel
,
Ms
,
Ns
,
Ks
,
...
...
profiler/src/profile_reduce.cpp
View file @
6e3cf8b0
...
...
@@ -144,7 +144,7 @@ class AppArgs
bool
do_dumpout
=
false
;
int
init_method
;
int
nrepeat
;
bool
time_kernel
;
bool
need_indices
=
false
;
...
...
@@ -295,7 +295,7 @@ class AppArgs
throw
std
::
runtime_error
(
"Invalid cmd-line arguments, more argumetns are needed!"
);
init_method
=
std
::
atoi
(
argv
[
optind
++
]);
nrepeat
=
std
::
atoi
(
argv
[
optind
]);
time_kernel
=
std
::
atoi
(
argv
[
optind
]);
if
(
scales
.
empty
())
{
...
...
@@ -354,7 +354,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -369,7 +369,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -387,7 +387,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -414,7 +414,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -429,7 +429,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -454,7 +454,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -471,7 +471,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
@@ -486,7 +486,7 @@ int profile_reduce(int argc, char* argv[])
args
.
init_method
,
args
.
do_log
,
args
.
do_dumpout
,
args
.
nrepeat
,
args
.
time_kernel
,
args
.
inLengths
,
args
.
reduceDims
,
args
.
reduceOp
,
...
...
test/CMakeLists.txt
View file @
6e3cf8b0
...
...
@@ -22,7 +22,8 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/external/include/half
)
add_custom_target
(
check COMMAND
${
CMAKE_CTEST_COMMAND
}
--output-on-failure -C
${
CMAKE_CFG_INTDIR
}
)
include
(
googletest
)
add_custom_target
(
tests
)
...
...
@@ -61,3 +62,6 @@ add_subdirectory(grouped_gemm)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
reduce
)
add_subdirectory
(
conv2d_bwd_weight
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
block_to_ctile_map
)
# DONOT add client_app, that is tested via CI independently
test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp
View file @
6e3cf8b0
...
...
@@ -22,7 +22,7 @@ int main()
Row
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
K
,
N
,
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
...
...
@@ -31,7 +31,7 @@ int main()
Row
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
K
,
K
,
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
...
...
@@ -40,7 +40,7 @@ int main()
Col
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
M
,
N
,
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
...
...
@@ -49,7 +49,7 @@ int main()
Col
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
BatchCount
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
M
,
K
,
N
,
BatchCount
);
if
(
pass
)
{
...
...
test/block_to_ctile_map/CMakeLists.txt
0 → 100644
View file @
6e3cf8b0
add_gtest_executable
(
test_block_to_ctile_map test_block_to_ctile_map.cpp
)
\ No newline at end of file
test/block_to_ctile_map/test_block_to_ctile_map.cpp
0 → 100644
View file @
6e3cf8b0
#include <ck/config.hpp>
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "gtest/gtest.h"
#include <iostream>
#include <vector>
using
namespace
ck
;
static
auto
I0
=
Number
<
0
>
{};
static
auto
I1
=
Number
<
1
>
{};
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
)
{
const
index_t
M
=
384
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
const
index_t
MBlock
=
M
/
MPerBlock
;
const
index_t
NBlock
=
N
/
NPerBlock
;
const
index_t
M01
=
4
;
const
index_t
N01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
I1
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
,
N01
);
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
),
true
>
tile_map
(
c_grid_desc_m_n
,
M01
,
N01
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
true
);
EXPECT_TRUE
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
==
16
);
// clang-format off
std
::
vector
<
std
::
vector
<
int
>>
expected
=
{
{
0
,
0
,
1
},
{
0
,
1
,
1
},
{
0
,
2
,
1
},
{
0
,
3
,
0
},
{
1
,
0
,
1
},
{
1
,
1
,
1
},
{
1
,
2
,
1
},
{
1
,
3
,
0
},
{
2
,
0
,
1
},
{
2
,
1
,
1
},
{
2
,
2
,
1
},
{
2
,
3
,
0
},
{
3
,
0
,
0
},
{
3
,
1
,
0
},
{
3
,
2
,
0
},
{
3
,
3
,
0
}
};
// clang-format on
for
(
index_t
i
=
0
;
i
<
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
);
i
++
)
{
auto
m0n0_idx
=
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
i
));
std
::
cout
<<
"block_1d_id = "
<<
i
<<
", m0, n0 = "
<<
m0n0_idx
[
I0
]
<<
", "
<<
m0n0_idx
[
I1
];
std
::
cout
<<
", valid = "
<<
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))
<<
std
::
endl
;
bool
equal
=
expected
[
i
]
==
std
::
vector
<
int
>
{
m0n0_idx
[
I0
],
m0n0_idx
[
I1
],
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))};
EXPECT_TRUE
(
equal
);
}
}
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0
)
{
const
index_t
M
=
384
;
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
// const index_t MBlock = M / MPerBlock;
// const index_t NBlock = N / NPerBlock;
const
index_t
M01
=
4
;
const
index_t
N01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
I1
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)
\n
"
,
M
,
N
,
MPerBlock
,
NPerBlock
,
M01
,
N01
);
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
decltype
(
c_grid_desc_m_n
),
false
>
tile_map
(
c_grid_desc_m_n
,
M01
,
N01
);
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
false
);
}
test/client_app/CMakeLists.txt
0 → 100644
View file @
6e3cf8b0
cmake_minimum_required
(
VERSION 3.15
)
project
(
ck_app
)
add_compile_options
(
-std=c++14
)
find_package
(
composable_kernel 1.0.0 COMPONENTS device_operations host_tensor
)
find_package
(
hip REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with HIP
${
hip_VERSION
}
"
)
add_executable
(
test_client_app client_app.cpp
)
target_link_libraries
(
test_client_app PRIVATE composable_kernel::device_operations composable_kernel::host_tensor hip::host
)
test/client_app/client_app.cpp
0 → 100644
View file @
6e3cf8b0
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <vector>
#include "client_app_impl.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
25
)
{
printf
(
"arg1: tensor operation (conv_fwd: ForwardConvolution)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: input tensor layout (0: NCHW; 1: NHWC)
\n
"
);
printf
(
"arg4: weight tensor layout (0: KCYX; 1: KYXC)
\n
"
);
printf
(
"arg5: output tensor layout (0: NKHW; 1: NHWK)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
1
);
}
const
ConvDataType
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
18
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
19
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
20
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
21
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
22
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
23
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
24
]);
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
ck
::
app
::
profile_conv_fwd_impl
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
data_type
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
return
1
;
}
test/client_app/client_app_impl.hpp
0 → 100644
View file @
6e3cf8b0
#pragma once
#include "host_interface.hpp"
enum
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
void
check_hip_error
(
void
)
{
hipError_t
err
=
hipGetLastError
();
if
(
err
!=
hipSuccess
)
{
std
::
cerr
<<
"Error: "
<<
hipGetErrorString
(
err
)
<<
std
::
endl
;
exit
(
err
);
}
}
std
::
string
getDeviceName
(
int
device
)
{
struct
hipDeviceProp_t
prop
;
hipGetDeviceProperties
(
&
prop
,
device
);
check_hip_error
();
return
std
::
string
(
prop
.
name
);
}
int
getDriver
(
void
)
{
int
driver
;
hipDriverGetVersion
(
&
driver
);
check_hip_error
();
return
driver
;
}
namespace
ck
{
namespace
app
{
struct
DeviceMem
{
DeviceMem
()
=
delete
;
DeviceMem
(
std
::
size_t
mem_size
);
void
*
GetDeviceBuffer
();
void
ToDevice
(
const
void
*
p
);
void
FromDevice
(
void
*
p
);
~
DeviceMem
();
void
*
mpDeviceBuf
;
std
::
size_t
mMemSize
;
};
DeviceMem
::
DeviceMem
(
std
::
size_t
mem_size
)
:
mMemSize
(
mem_size
)
{
hipGetErrorString
(
hipMalloc
(
static_cast
<
void
**>
(
&
mpDeviceBuf
),
mMemSize
));
}
void
*
DeviceMem
::
GetDeviceBuffer
()
{
return
mpDeviceBuf
;
}
void
DeviceMem
::
ToDevice
(
const
void
*
p
)
{
hipGetErrorString
(
hipMemcpy
(
mpDeviceBuf
,
const_cast
<
void
*>
(
p
),
mMemSize
,
hipMemcpyHostToDevice
));
}
void
DeviceMem
::
FromDevice
(
void
*
p
)
{
hipGetErrorString
(
hipMemcpy
(
p
,
mpDeviceBuf
,
mMemSize
,
hipMemcpyDeviceToHost
));
}
DeviceMem
::~
DeviceMem
()
{
hipGetErrorString
(
hipFree
(
mpDeviceBuf
));
}
void
profile_conv_fwd_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
bool
time_kernel
,
ConvDataType
data_type
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
const
ck
::
index_t
Y
=
filter_spatial_lengths
[
0
];
const
ck
::
index_t
X
=
filter_spatial_lengths
[
1
];
const
ck
::
index_t
Hi
=
input_spatial_lengths
[
0
];
const
ck
::
index_t
Wi
=
input_spatial_lengths
[
1
];
const
ck
::
index_t
Ho
=
output_spatial_lengths
[
0
];
const
ck
::
index_t
Wo
=
output_spatial_lengths
[
1
];
const
auto
in_sz
=
N
*
C
*
Hi
*
Wi
;
const
auto
wei_sz
=
K
*
C
*
Y
*
X
;
const
auto
out_sz
=
N
*
K
*
Ho
*
Wo
;
using
WeiDataType
=
float
;
using
InDataType
=
float
;
using
OutDataType
=
float
;
app
::
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_sz
);
app
::
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_sz
);
app
::
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_sz
);
// data is already on device!
// add device Conv instances
std
::
vector
<
DeviceConvFwdPtr_t
>
conv_ptrs
;
if
(
data_type
==
F16_F16_F16
)
{
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
conv_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
conv_ptrs
);
}
else
if
(
data_type
==
BF16_BF16_BF16
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
conv_ptrs
);
else
if
(
data_type
==
F32_F32_F32
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
conv_ptrs
);
else
if
(
data_type
==
INT8_INT8_INT8
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
conv_ptrs
);
else
throw
std
::
runtime_error
(
"wrong! Invalid data type"
);
if
(
conv_ptrs
.
empty
())
{
throw
std
::
runtime_error
(
"wrong! no device Conv instance found"
);
}
std
::
string
best_conv_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
int
deviceIndex
=
0
;
hipSetDevice
(
deviceIndex
);
check_hip_error
();
StreamConfig
stream_config
{
nullptr
,
time_kernel
};
hipStreamCreate
(
&
stream_config
.
stream_id_
);
check_hip_error
();
// profile device Conv instances
for
(
auto
&
conv_ptr
:
conv_ptrs
)
{
auto
argument_ptr
=
conv_ptr
.
MakeArgumentPointer
(
static_cast
<
void
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
void
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
void
*>
(
out_device_buf
.
GetDeviceBuffer
()),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
auto
invoker_ptr
=
conv_ptr
.
MakeInvokerPointer
();
if
(
conv_ptr
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
conv_name
=
conv_ptr
.
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
stream_config
);
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
std
::
size_t
num_btype
=
sizeof
(
InDataType
)
*
(
N
*
C
*
Hi
*
Wi
)
+
sizeof
(
WeiDataType
)
*
(
K
*
C
*
Y
*
X
)
+
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
conv_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_conv_name
=
conv_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_conv_name
<<
std
::
endl
;
}
}
// namespace app
}
// namespace ck
test/conv2d_bwd_weight/conv2d_bwd_weight.cpp
View file @
6e3cf8b0
...
...
@@ -28,10 +28,10 @@ int test_self()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -52,10 +52,10 @@ int test_self()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -72,8 +72,8 @@ int test_self()
}
int
main
(
int
argc
,
char
*
argv
[])
{
int
data_type
=
0
;
int
init_method
=
0
;
int
data_type
=
1
;
int
init_method
=
1
;
// Conv shape
ck
::
index_t
N
=
128
;
...
...
@@ -155,10 +155,10 @@ int main(int argc, char* argv[])
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
true
,
// do_verification
init_method
,
0
,
1
,
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -180,10 +180,10 @@ int main(int argc, char* argv[])
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
true
,
// do_verification
init_method
,
0
,
1
,
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
test/convnd_bwd_data/convnd_bwd_data.cpp
View file @
6e3cf8b0
...
...
@@ -27,10 +27,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -50,10 +50,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -73,10 +73,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -96,10 +96,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -128,10 +128,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -151,10 +151,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -174,10 +174,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -197,10 +197,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -232,10 +232,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -255,10 +255,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -278,10 +278,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
@@ -301,10 +301,10 @@ int main()
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
1
,
// do_verification
,
1
,
// init_method
,
0
,
// do_log
,
1
,
// nrepeat,
true
,
// do_verification
1
,
// init_method
false
,
// do_log
false
,
// time_kernel
param
.
N_
,
param
.
K_
,
param
.
C_
,
...
...
test/gemm_reduce/gemm_reduce_fp16.cpp
View file @
6e3cf8b0
...
...
@@ -16,22 +16,22 @@ int main()
pass
=
pass
&&
ck
::
profiler
::
profile_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
Row
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
K
,
N
,
N
);
pass
=
pass
&&
ck
::
profiler
::
profile_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
Row
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
K
,
K
,
N
);
pass
=
pass
&&
ck
::
profiler
::
profile_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
Col
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
M
,
N
,
N
);
pass
=
pass
&&
ck
::
profiler
::
profile_gemm_reduce_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
Col
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
);
true
,
1
,
false
,
false
,
M
,
N
,
K
,
M
,
K
,
N
);
if
(
pass
)
{
...
...
test/gemm_split_k/gemm_split_k.cpp
View file @
6e3cf8b0
...
...
@@ -187,9 +187,10 @@ int test_gemm(const gemmArgs& args)
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
()
,
0
);
invoker_ptr
->
Run
(
argument_ptr
.
get
());
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
if
(
!
check_out
(
c_m_n_host_result
,
c_m_n_device_result
))
{
success
=
false
;
...
...
Prev
1
…
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment