gaoqiong / composable_kernel · Commits · Commit e00a943e

Authored May 17, 2022 by myamlak

    Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

Parents: ffe12e2e, 9f71ff48
Changes: 162 files in the merge; this page shows 20 changed files with 462 additions and 220 deletions (+462 -220).
Files shown on this page:

  example/10_conv2d_bwd_data/CMakeLists.txt                                        +1   -1
  example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp                               +12  -8
  example/11_conv2d_bwd_weight/CMakeLists.txt                                      +1   -1
  example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp                           +11  -8
  example/12_reduce/CMakeLists.txt                                                 +1   -1
  example/12_reduce/reduce_blockwise.cpp                                           +10  -10
  example/13_pool2d_fwd/pool2d_fwd.cpp                                             +12  -10
  example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp  +8   -8
  example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp                                +9   -8
  example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp                                  +26  -32
  example/17_convnd_bwd_data_xdl/CMakeLists.txt                                    +1   -1
  example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp                           +59  -55
  example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp                  +33  -40
  example/19_cgemm/cgemm_xdl_bf16.cpp                                              +6   -6
  example/CMakeLists.txt                                                           +11  -2
  include/ck/config.hpp                                                            +4   -0
  include/ck/hip_version.hpp.in                                                    +0   -28
  include/ck/options.hpp.in                                                        +3   -0
  include/ck/stream_config.hpp                                                     +10  -0
  include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                  +244 -1
example/10_conv2d_bwd_data/CMakeLists.txt (view file @ e00a943e)

 add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
-target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util)
+target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_util)
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp (view file @ e00a943e)

@@ -77,9 +77,9 @@ using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdDat
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;

     // Conv shape
     ck::index_t N = 128;
@@ -102,13 +102,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc == 19)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);

         N = std::stoi(argv[4]);
         K = std::stoi(argv[5]);
@@ -130,7 +130,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
                "RightPx\n");
         exit(0);
@@ -214,7 +214,7 @@ int main(int argc, char* argv[])
                                      "not support this Conv problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
@@ -249,6 +249,10 @@ int main(int argc, char* argv[])
         in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());

-        ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData);
+        return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
+                                    in_n_c_hi_wi_host_result.mData)
+                   ? 0
+                   : 1;
     }
+
+    return 0;
 }
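Two mechanical changes repeat across the example mains in this commit: the third command-line argument now toggles timing instead of setting a repeat count, and the verification result now drives the process exit code so the add_test() registration added in example/CMakeLists.txt can report pass/fail through CTest. Below is a minimal, self-contained sketch of the exit-code half of that pattern; check_values is a hypothetical stand-in for ck::utils::check_err (which returns a bool and, per this diff, also accepts a message and tolerances).

#include <cstdlib>
#include <vector>

// stand-in for ck::utils::check_err; the real function also prints a message
// and applies relative/absolute tolerances
bool check_values(const std::vector<float>& dev, const std::vector<float>& host)
{
    return dev == host;
}

int main(int argc, char* argv[])
{
    bool do_verification = (argc >= 2) ? std::atoi(argv[1]) : true;

    std::vector<float> device_result{1.f, 2.f, 3.f};
    std::vector<float> host_result{1.f, 2.f, 3.f};

    if(do_verification)
    {
        // before this commit the examples discarded the check result and always exited 0
        return check_values(device_result, host_result) ? 0 : 1;
    }
    return 0;
}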
example/11_conv2d_bwd_weight/CMakeLists.txt (view file @ e00a943e)

 add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
-target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util)
+target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_util)
example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp (view file @ e00a943e)

@@ -82,9 +82,9 @@ using ReferenceConvBwdWeightInstance =
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
     int do_log           = 0;
     int split_k          = 4;
@@ -109,7 +109,7 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
         do_log          = std::stoi(argv[4]);
         split_k         = std::stoi(argv[5]);
     }
@@ -117,7 +117,7 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
         do_log          = std::stoi(argv[4]);
         split_k         = std::stoi(argv[5]);
@@ -141,7 +141,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         printf("arg4: is show log (0=no, 1=yes)\n");
         printf("arg5: split-k\n");
         printf("arg6 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
         return 1;
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
@@ -291,6 +291,9 @@ int main(int argc, char* argv[])
             LogRangeAsType<float>(std::cout << "wei_host  : ", wei_k_c_y_x_host_result.mData, ",")
                 << std::endl;
         }

-        ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData);
+        return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
+                                    wei_k_c_y_x_host_result.mData)
+                   ? 0
+                   : 1;
     }
+
+    return 0;
 }
example/12_reduce/CMakeLists.txt (view file @ e00a943e)

-add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
+add_example_executable(example_reduce_blockwise reduce_blockwise.cpp -D 16,64,32,960 -v 1 1 10)
example/12_reduce/reduce_blockwise.cpp (view file @ e00a943e)

@@ -116,10 +116,9 @@ class SimpleAppArgs
     std::vector<size_t> inLengths;
     std::vector<float> scales;

-    bool do_verification = false;
+    bool do_verification = true;
     int init_method      = 1;
-    int nrepeat          = 5;
+    bool time_kernel     = false;

     public:
     void show_usage(const char* cmd)
@@ -135,7 +134,7 @@ class SimpleAppArgs
         std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
                      "value, 3=decimal value)"
                   << std::endl;
-        std::cout << "Arg2 -- number of repeats to run the kernel" << std::endl;
+        std::cout << "Arg2 -- time kernel (0=n0, 1=yes)" << std::endl;
     };

     int processArgs(int argc, char* argv[])
@@ -182,7 +181,7 @@ class SimpleAppArgs
             throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");

         init_method = std::atoi(argv[optind++]);
-        nrepeat     = std::atoi(argv[optind]);
+        time_kernel = std::atoi(argv[optind]);

         if(scales.empty())
         {
@@ -352,7 +351,7 @@ int main(int argc, char* argv[])
     auto invoker_ptr = reduce.MakeInvokerPointer();

-    float avg_time = invoker_ptr->Run(argument_ptr.get(), args.nrepeat);
+    float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});

     std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
                             invariant_total_length * sizeof(OutDataType);
@@ -362,16 +361,17 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
               << std::endl;

+    bool pass = true;
     if(args.do_verification)
     {
         out_dev.FromDevice(out.mData.data());
-        ck::utils::check_err(out.mData, out_ref.mData);
+        pass &= ck::utils::check_err(out.mData, out_ref.mData);

         if(NeedIndices)
         {
             out_indices_dev.FromDevice(out_indices.mData.data());
-            ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
+            pass &= ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
             ;
         };
     };

+    return pass ? 0 : 1;
 }
example/13_pool2d_fwd/pool2d_fwd.cpp (view file @ e00a943e)

@@ -149,9 +149,9 @@ int main(int argc, char* argv[])
 {
     using namespace ck::host_reduce;

-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;

     // Pool shape
     ck::index_t N = 128;
@@ -171,13 +171,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc == 16)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);

         N = std::stoi(argv[4]);
         C = std::stoi(argv[5]);
@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
                "RightPx\n");
         exit(0);
@@ -271,7 +271,7 @@ int main(int argc, char* argv[])
                                     "not support this problem");
     }

-    float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
+    float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X;
@@ -285,6 +285,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         pool_host_verify<InDataType,
@@ -302,14 +303,15 @@ int main(int argc, char* argv[])
         out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());

-        ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
+        pass &= ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);

         if constexpr(NeedIndices)
         {
             out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());

-            // ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
-            //                      out_indices_n_c_ho_wo_host.mData);
-            ;
+            pass &= ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
+                                         out_indices_n_c_ho_wo_host.mData);
         };
     }

+    return pass ? 0 : 1;
 }
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp (view file @ e00a943e)

@@ -105,9 +105,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;

     // GEMM shape
     ck::index_t M = 3840;
@@ -125,13 +125,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -145,7 +145,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
         exit(0);
     }
@@ -219,7 +219,7 @@ int main(int argc, char* argv[])
                                      "not support this GEMM problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
@@ -244,7 +244,7 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);

-        ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
+        return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
     }

     return 0;
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp (view file @ e00a943e)

@@ -60,21 +60,21 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;

     if(argc == 4)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         exit(0);
     }
@@ -202,7 +202,7 @@ int main(int argc, char* argv[])
                                      "not support this GEMM problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -211,6 +211,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         for(std::size_t i = 0; i < gemm_shapes.size(); i++)
@@ -227,9 +228,9 @@ int main(int argc, char* argv[])
                                                       c_element_op);

             ref_invoker.Run(ref_argument);

-            ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
+            pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
         }
     }

-    return 0;
+    return pass ? 0 : 1;
 }
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp (view file @ e00a943e)

@@ -4,6 +4,7 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include <half.hpp>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -58,9 +59,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 1;
+    bool do_verification = true;
     int init_method      = 1;
-    int nrepeat          = 5;
+    bool time_kernel     = false;

     // GEMM shape
     ck::index_t M = 3840;
@@ -79,13 +80,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -99,7 +100,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
         exit(0);
     }
@@ -192,30 +193,13 @@ int main(int argc, char* argv[])
                                      "not support this GEMM problem");
     }

-    // warm up
-    invoker.Run(argument);
-
-    // timing
-    float total_time = 0;
-
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        // init DO, D1 to 0
-        d0_device_buf.SetZero();
-        d1_device_buf.SetZero();
-
-        KernelTimer timer;
-        timer.Start();
-
-        invoker.Run(argument);
-
-        timer.End();
-
-        total_time += timer.GetElapsedTime();
-    }
-
-    float ave_time = total_time / nrepeat;
+    // init DO, D1 to 0
+    d0_device_buf.SetZero();
+    d1_device_buf.SetZero();
+
+    // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
+    // will not be correct. need to set time_kernel = false for correctness test
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
@@ -228,6 +212,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         c_device_buf.FromDevice(c_m_n_device_result.mData.data());
@@ -264,10 +249,19 @@ int main(int argc, char* argv[])
             d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
         }

-        check_error(c_m_n_host_result, c_m_n_device_result);
-        check_error(d0_m_host_result, d0_m_device_result);
-        check_error(d1_m_host_result, d1_m_device_result);
+        pass &= ck::utils::check_err(
+            c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c");
+        pass &= ck::utils::check_err(d0_m_device_result.mData,
+                                     d0_m_host_result.mData,
+                                     "Error: Incorrect results d0", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(d1_m_device_result.mData,
+                                     d1_m_host_result.mData,
+                                     "Error: Incorrect results d1", 1e-3, 1e-3);
     }

-    return 0;
+    return pass ? 0 : 1;
 }
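The comment added above explains why the old timing loop had to go: the reduction outputs D0/D1 are accumulated with atomic adds, so every additional (timed) launch adds into the same buffers and only a single launch into zero-initialized buffers gives a verifiable result. A self-contained host-side sketch of that effect, with a plain loop standing in for the device kernel (not the CK code):

#include <cassert>
#include <vector>

// models the device kernel's atomic_add accumulation into the D output
void reduce_rows_atomic_add(const std::vector<float>& c, int M, int N, std::vector<float>& d)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            d[m] += c[m * N + n];
}

int main()
{
    const int M = 2, N = 3;
    std::vector<float> c{1, 2, 3, 4, 5, 6};
    std::vector<float> d(M, 0.f);            // SetZero() before the launch

    reduce_rows_atomic_add(c, M, N, d);      // single launch: correct
    assert(d[0] == 6 && d[1] == 15);

    reduce_rows_atomic_add(c, M, N, d);      // a second (timing) launch doubles the result,
    assert(d[0] == 12 && d[1] == 30);        // hence time_kernel=false for correctness tests
    return 0;
}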
example/17_convnd_bwd_data_xdl/CMakeLists.txt (view file @ e00a943e)

 add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
-target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
+target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util)
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp (view file @ e00a943e)

@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "config.hpp"
-#include "conv_fwd_util.hpp"
+#include "conv_util.hpp"
 #include "print.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -87,7 +87,7 @@ void print_use_msg()
     std::cout << "arg1: verification (0=no, 1=yes)\n"
               << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
-              << "arg3: run kernel # of times (>1)\n"
+              << "arg3: time kernel (0=n0, 1=yes)\n"
               << "arg4: N spatial dimensions (default 2)\n"
               << "Following arguments (depending on number of spatial dims):\n"
               << " N, K, C, \n"
@@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
     ck::utils::conv::ConvParams params;
     int arg_idx = 5;

-    params.num_dim_spatial = num_dim_spatial;
-    params.N               = std::stoi(argv[arg_idx++]);
-    params.K               = std::stoi(argv[arg_idx++]);
-    params.C               = std::stoi(argv[arg_idx++]);
+    params.num_dim_spatial_ = num_dim_spatial;
+    params.N_               = std::stoi(argv[arg_idx++]);
+    params.K_               = std::stoi(argv[arg_idx++]);
+    params.C_               = std::stoi(argv[arg_idx++]);

-    params.filter_spatial_lengths.resize(num_dim_spatial);
+    params.filter_spatial_lengths_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
+        params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.input_spatial_lengths.resize(num_dim_spatial);
+    params.input_spatial_lengths_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
+        params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.conv_filter_strides.resize(num_dim_spatial);
+    params.conv_filter_strides_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
+        params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.conv_filter_dilations.resize(num_dim_spatial);
+    params.conv_filter_dilations_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
+        params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.input_left_pads.resize(num_dim_spatial);
+    params.input_left_pads_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
+        params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
     }
-    params.input_right_pads.resize(num_dim_spatial);
+    params.input_right_pads_.resize(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
+        params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
     }

     return params;
@@ -165,25 +165,25 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
     int num_dim_spatial  = 2;

     ck::utils::conv::ConvParams params;
-    params.C = 128;
+    params.C_ = 128;

     if(argc == 4)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc > 4)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
         num_dim_spatial = std::stoi(argv[4]);
         // check args number
         int conv_args = 3 + num_dim_spatial * 6;
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
         exit(1);
     }

-    std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
-                                        static_cast<std::size_t>(params.C)};
+    std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
+                                        static_cast<std::size_t>(params.C_)};
     input_dims.insert(std::end(input_dims),
-                      std::begin(params.input_spatial_lengths),
-                      std::end(params.input_spatial_lengths));
+                      std::begin(params.input_spatial_lengths_),
+                      std::end(params.input_spatial_lengths_));

-    std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
-                                         static_cast<std::size_t>(params.C)};
+    std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
+                                         static_cast<std::size_t>(params.C_)};
     filter_dims.insert(std::end(filter_dims),
-                       std::begin(params.filter_spatial_lengths),
-                       std::end(params.filter_spatial_lengths));
+                       std::begin(params.filter_spatial_lengths_),
+                       std::end(params.filter_spatial_lengths_));

     const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();

-    std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
-                                         static_cast<std::size_t>(params.K)};
+    std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
+                                         static_cast<std::size_t>(params.K_)};
     output_dims.insert(std::end(output_dims),
                        std::begin(output_spatial_lengths),
                        std::end(output_spatial_lengths));
@@ -263,16 +263,16 @@ int main(int argc, char* argv[])
         conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                   static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                                   static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
-                                  params.N,
-                                  params.K,
-                                  params.C,
-                                  params.input_spatial_lengths,
-                                  params.filter_spatial_lengths,
+                                  params.N_,
+                                  params.K_,
+                                  params.C_,
+                                  params.input_spatial_lengths_,
+                                  params.filter_spatial_lengths_,
                                   output_spatial_lengths,
-                                  params.conv_filter_strides,
-                                  params.conv_filter_dilations,
-                                  params.input_left_pads,
-                                  params.input_right_pads,
+                                  params.conv_filter_strides_,
+                                  params.conv_filter_dilations_,
+                                  params.input_left_pads_,
+                                  params.input_right_pads_,
                                   InElementOp{},
                                   WeiElementOp{},
                                   OutElementOp{});
@@ -284,16 +284,16 @@ int main(int argc, char* argv[])
                                      "not support this Conv problem");
     }

-    float ave_time = invoker->Run(argument.get(), nrepeat);
+    float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});

     std::size_t flop = ck::utils::conv::get_flops(
-        params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
+        params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
     std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
-        params.N,
-        params.C,
-        params.K,
-        params.input_spatial_lengths,
-        params.filter_spatial_lengths,
+        params.N_,
+        params.C_,
+        params.K_,
+        params.input_spatial_lengths_,
+        params.filter_spatial_lengths_,
         output_spatial_lengths);

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -310,10 +310,10 @@ int main(int argc, char* argv[])
         auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
                                                   wei_k_c_y_x,
                                                   out_n_k_ho_wo,
-                                                  params.conv_filter_strides,
-                                                  params.conv_filter_dilations,
-                                                  params.input_left_pads,
-                                                  params.input_right_pads,
+                                                  params.conv_filter_strides_,
+                                                  params.conv_filter_dilations_,
+                                                  params.input_left_pads_,
+                                                  params.input_right_pads_,
                                                   InElementOp{},
                                                   WeiElementOp{},
                                                   OutElementOp{});
@@ -322,7 +322,10 @@ int main(int argc, char* argv[])
         in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());

-        check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
+        return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
+                                    in_n_c_hi_wi_host_result.mData)
+                   ? 0
+                   : 1;
     };

     switch(num_dim_spatial)
@@ -347,4 +350,5 @@ int main(int argc, char* argv[])
             }
         }
     }
+    return 0;
 }
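The bulk of this file's churn is the rename of ck::utils::conv::ConvParams members to a trailing-underscore style (N -> N_, filter_spatial_lengths -> filter_spatial_lengths_, and so on). A tiny stand-in sketch of the new spelling only; the real class lives in CK's conv utilities and carries more than what is listed here, and the numeric values are illustrative.

#include <vector>

struct ConvParamsSketch   // stand-in; member names match the renamed fields used above
{
    int num_dim_spatial_ = 2;
    int N_ = 128, K_ = 256, C_ = 128;
    std::vector<int> filter_spatial_lengths_, input_spatial_lengths_;
    std::vector<int> conv_filter_strides_, conv_filter_dilations_;
    std::vector<int> input_left_pads_, input_right_pads_;
};

int main()
{
    ConvParamsSketch params;
    params.filter_spatial_lengths_ = {3, 3};   // was params.filter_spatial_lengths
    params.input_spatial_lengths_  = {71, 71};
    params.conv_filter_strides_    = {2, 2};
    params.conv_filter_dilations_  = {1, 1};
    params.input_left_pads_        = {1, 1};
    params.input_right_pads_       = {1, 1};
    return params.num_dim_spatial_ == 2 ? 0 : 1;
}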
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp (view file @ e00a943e)

@@ -4,6 +4,7 @@
 #include <cstdlib>
 #include <stdlib.h>
 #include <half.hpp>
+#include "check_err.hpp"
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
@@ -57,18 +58,18 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 1;
+    bool do_verification = true;
     int init_method      = 1;
-    int nrepeat          = 5;
+    bool time_kernel     = false;

     // GEMM shape
-    ck::index_t M = 3840;
-    ck::index_t N = 4096;
-    ck::index_t K = 4096;
+    ck::index_t M = 2048;
+    ck::index_t N = 1920;
+    ck::index_t K = 2048;

-    ck::index_t StrideA = 4096;
-    ck::index_t StrideB = 4096;
-    ck::index_t StrideC = 4096;
+    ck::index_t StrideA = 2048;
+    ck::index_t StrideB = 2048;
+    ck::index_t StrideC = 1920;

     ck::index_t BatchCount = 4;
@@ -80,13 +81,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc == 11)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -96,13 +97,13 @@ int main(int argc, char* argv[])
         StrideB = std::stoi(argv[8]);
         StrideC = std::stoi(argv[9]);

-        BatchCount = std::stoi(argv[9]);
+        BatchCount = std::stoi(argv[10]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: run kernel # of times (>1)\n");
+        printf("arg3: time kernel (0=n0, 1=yes)\n");
         printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n");
         exit(0);
     }
@@ -204,30 +205,13 @@ int main(int argc, char* argv[])
                                      "not support this GEMM problem");
     }

-    // warm up
-    invoker.Run(argument);
-
-    // timing
-    float total_time = 0;
-
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        // init DO, D1 to 0
-        d0_device_buf.SetZero();
-        d1_device_buf.SetZero();
-
-        KernelTimer timer;
-        timer.Start();
-
-        invoker.Run(argument);
-
-        timer.End();
-
-        total_time += timer.GetElapsedTime();
-    }
-
-    float ave_time = total_time / nrepeat;
+    // init DO, D1 to 0
+    d0_device_buf.SetZero();
+    d1_device_buf.SetZero();
+
+    // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
+    // will not be correct. need to set time_kernel = false for correctness test
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
@@ -241,6 +225,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << batched_gemm.GetTypeString() << std::endl;

+    bool pass = true;
     if(do_verification)
     {
         c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
@@ -264,7 +249,7 @@ int main(int argc, char* argv[])
             for(int n = 0; n < N; ++n)
             {
-                float d0_val = ck::type_convert<float>(c_g_m_n_host_result(m, n));
+                float d0_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n));
                 float d1_val;

                 d1_element_op(d1_val, d0_val);
@@ -277,10 +262,18 @@ int main(int argc, char* argv[])
             }
         }

-        check_error(c_g_m_n_host_result, c_g_m_n_device_result);
-        check_error(d0_g_m_host_result, d0_g_m_device_result);
-        check_error(d1_g_m_host_result, d1_g_m_device_result);
+        pass &= ck::utils::check_err(c_g_m_n_host_result.mData, c_g_m_n_device_result.mData);
+        pass &= ck::utils::check_err(d0_g_m_device_result.mData,
+                                     d0_g_m_host_result.mData,
+                                     "Error: Incorrect results! D0", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(d1_g_m_device_result.mData,
+                                     d1_g_m_host_result.mData,
+                                     "Error: Incorrect results! D1", 1e-3, 1e-3);
     }

-    return 0;
+    return pass ? 0 : 1;
 }
example/19_cgemm/cgemm_xdl_bf16.cpp (view file @ e00a943e)

@@ -88,9 +88,9 @@ using ReferenceCGemmInstance = ck::tensor_operation::host::
 int main(int argc, char* argv[])
 {
-    bool do_verification = 0;
-    int init_method      = 0;
-    int nrepeat          = 5;
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;

     // CGEMM shape
     ck::index_t M = 3840;
@@ -105,13 +105,13 @@ int main(int argc, char* argv[])
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);
     }
     else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
-        nrepeat         = std::stoi(argv[3]);
+        time_kernel     = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
@@ -223,7 +223,7 @@ int main(int argc, char* argv[])
                                      "not support this CGEMM problem");
     }

-    float ave_time = invoker.Run(argument, nrepeat);
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = std::size_t(8) * M * N * K;
     std::size_t num_btype = std::size_t(2) * sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
example/CMakeLists.txt (view file @ e00a943e)

@@ -19,9 +19,18 @@ include_directories(BEFORE
 add_custom_target(examples)

-function(add_example_executable EXAMPLE_NAME)
+function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
-    add_executable(${EXAMPLE_NAME} ${ARGN})
+    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
+    target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
+    add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
+    add_dependencies(examples ${EXAMPLE_NAME})
+    add_dependencies(check ${EXAMPLE_NAME})
+endfunction(add_example_executable EXAMPLE_NAME)
+
+function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
+    message("adding example ${EXAMPLE_NAME}")
+    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
     target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
     add_dependencies(examples ${EXAMPLE_NAME})
 endfunction(add_example_executable EXAMPLE_NAME)
include/ck/config.hpp (view file @ e00a943e)

@@ -109,6 +109,10 @@
 // experimental feature: use __builtin_memcpy instead of union to do bit_cast
 #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1

+// experimental feature: optimize for inter-wave scheduling policy
+#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0
+#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1
+
 // hack: have underlying assumption that need to be satsified, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
 // thread-invariant, otherwise it's a bug
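The second macro sizes the "MAC clusters" of the new inter-wave scheduler: later in this commit, blockwise_gemm_xdlops.hpp uses it as the default for NumMacClusters and derives KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack). A stand-alone restatement of that formula; the numeric cases below are illustrative, not values taken from the diff.

#include <algorithm>

// same formula as KPerInnerLoop in the inter-wave blockwise GEMM added below
constexpr int k_per_inner_loop(int k_per_thread, int num_mac_clusters, int k_pack)
{
    return std::max(k_per_thread / num_mac_clusters, k_pack);
}

static_assert(k_per_inner_loop(16, 1, 4) == 16, "one cluster: the whole K slice per inner loop");
static_assert(k_per_inner_loop(16, 2, 4) == 8,  "two clusters split KPerThread in half");
static_assert(k_per_inner_loop(4, 2, 4) == 4,   "never drops below KPack");

int main() { return 0; }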
include/ck/hip_version.hpp.in (deleted, 100644 → 0; view file @ ffe12e2e)

Entire file removed:

#pragma once

// "_PACKAGE_" to avoid name contentions: the macros like
// HIP_VERSION_MAJOR are defined in HIP_VERSION.h.
// clang-format off
#define CK_HIP_PACKAGE_VERSION_MAJOR @CK_HIP_VERSION_MAJOR@
#define CK_HIP_PACKAGE_VERSION_MINOR @CK_HIP_VERSION_MINOR@
#define CK_HIP_PACKAGE_VERSION_PATCH @CK_HIP_VERSION_PATCH@
// clang-format on

#ifndef CK_HIP_PACKAGE_VERSION_MAJOR
#define CK_HIP_PACKAGE_VERSION_MAJOR 0
#endif
#ifndef CK_HIP_PACKAGE_VERSION_MINOR
#define CK_HIP_PACKAGE_VERSION_MINOR 0
#endif
#ifndef CK_HIP_PACKAGE_VERSION_PATCH
#define CK_HIP_PACKAGE_VERSION_PATCH 0
#endif

// 3 decimal digits for major and minor, 6 digits for patch number.
// Max number is 999,999,999999 == 0xE8,D4A5,0FFF that fits into 64-bit math.
#if CK_HIP_PACKAGE_VERSION_MAJOR > 999 || CK_HIP_PACKAGE_VERSION_MAJOR > 999 || \
    CK_HIP_PACKAGE_VERSION_PATCH > 999999
#error "Too big HIP version number(s)"
#endif

#define CK_HIP_PACKAGE_VERSION_FLAT \
    ((CK_HIP_PACKAGE_VERSION_MAJOR * 1000ULL + CK_HIP_PACKAGE_VERSION_MINOR) * 1000000 + \
     CK_HIP_PACKAGE_VERSION_PATCH)
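For reference, a worked check of the flat-version encoding the removed header defined, (major * 1000 + minor) * 1000000 + patch; the version 5.2.21153 below is hypothetical and only exercises the arithmetic.

static_assert((5ULL * 1000 + 2) * 1000000 + 21153 == 5002021153ULL,
              "5.2.21153 flattens to 5'002'021'153");

int main() { return 0; }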
include/ck/options.hpp.in (new file, 0 → 100644; view file @ e00a943e)

#pragma once

#cmakedefine01 CK_TIME_KERNEL
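After CMake configures this template, "#cmakedefine01 CK_TIME_KERNEL" becomes "#define CK_TIME_KERNEL 1" when the corresponding CMake option is ON and "#define CK_TIME_KERNEL 0" when it is OFF. How CK consumes the macro is not shown in this diff; the sketch below only illustrates the shape of a consumer, with a fallback definition standing in for the generated header.

#include <iostream>

#ifndef CK_TIME_KERNEL
#define CK_TIME_KERNEL 0 // stand-in for the value produced by options.hpp.in
#endif

int main()
{
    // e.g. a default for StreamConfig::time_kernel_ (assumption, not from the diff)
    bool time_kernel_default = CK_TIME_KERNEL;
    std::cout << "timing enabled by default: " << std::boolalpha << time_kernel_default << '\n';
    return 0;
}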
include/ck/stream_config.hpp (new file, 0 → 100644; view file @ e00a943e)

#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

struct StreamConfig
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
};
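The examples above consume this header by aggregate-initializing the struct with a stream handle (nullptr meaning the default/null stream) and the timing switch, then handing it to the device op's invoker, e.g. invoker.Run(argument, StreamConfig{nullptr, time_kernel}). A host-only sketch of that calling shape; FakeInvoker is a hypothetical stand-in and the stream field is typed void* here so the sketch compiles without the HIP headers.

#include <iostream>

struct StreamConfig                 // same two fields as include/ck/stream_config.hpp
{
    void* stream_id_  = nullptr;    // hipStream_t in the real header
    bool time_kernel_ = false;
};

struct FakeInvoker                  // stand-in for a CK device-op invoker
{
    float Run(const StreamConfig& cfg) const
    {
        // a real invoker launches the kernel; per this diff it only runs the
        // kernel repeatedly and reports an average time when time_kernel_ is set
        return cfg.time_kernel_ ? 0.42f : 0.0f; // pretend average milliseconds
    }
};

int main()
{
    bool time_kernel = true;
    FakeInvoker invoker;
    float ave_time = invoker.Run(StreamConfig{nullptr, time_kernel});
    std::cout << "Perf: " << ave_time << " ms\n";
    return 0;
}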
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp (view file @ e00a943e)

@@ -7,6 +7,21 @@
 namespace ck {

+enum struct LoopScheduler
+{
+    Default,
+    Interwave,
+};
+
+constexpr LoopScheduler make_default_loop_scheduler()
+{
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    return LoopScheduler::Interwave;
+#else
+    return LoopScheduler::Default;
+#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+}
+
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
@@ -302,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         });
     }

-    private:
+    protected:
     // A[M0, M1, M2, KPerThread]
     static constexpr auto a_thread_desc_ =
         make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
@@ -339,4 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 };

+// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
+// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
+// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
+// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
+template <index_t BlockSize, typename FloatAB, typename FloatAcc,
+          typename AK0MK1BlockDesc, typename BK0NK1BlockDesc,
+          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPack,
+          index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
+struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
+    : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+          BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+          MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
+{
+    using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+        BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+        MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>;
+
+#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+    using Base::a_block_desc_m0_m1_m2_k;
+    using Base::A_K1;
+    using Base::b_block_desc_n0_n1_n2_k;
+    using Base::B_K1;
+    using Base::c_thread_buf_;
+    using Base::c_thread_desc_;
+    using Base::CalculateAThreadOriginDataIndex;
+    using Base::CalculateBThreadOriginDataIndex;
+    using Base::I0;
+    using Base::I1;
+    using Base::KPerThread;
+    using Base::xdlops_gemm;
+
+    static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
+
+    // 2-wave optimized blockwise gemm
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_thread_desc_.GetElementSpaceSize());
+
+        static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                // read A
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, make_tuple(m0, I0, I0, k), a_block_buf,
+                                   a_thread_desc_, make_tuple(m0, I0, I0, I0), a_thread_buf);
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, make_tuple(n0, I0, I0, k), b_block_buf,
+                                   b_thread_desc_, make_tuple(n0, I0, I0, I0), b_thread_buf);
+            });
+            __builtin_amdgcn_sched_barrier();
+            // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except
+            // the first, as we can shorten non-MAC cluster a bit and there's no observable negative
+            // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
+            // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the
+            // chance of latency hiding by waiting for the rest of the workgroup at the eventual
+            // sync point.
+            if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
+            {
+                asm volatile("s_barrier" ::);
+                __builtin_amdgcn_sched_barrier();
+            }
+            static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        vector_type<FloatAB, KPack> a_thread_vec;
+                        vector_type<FloatAB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto i) {
+                            a_thread_vec.template AsType<FloatAB>()(i) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, 0, 0, k_ + i))>{}];
+                            b_thread_vec.template AsType<FloatAB>()(i) =
+                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                    make_tuple(n0, 0, 0, k_ + i))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        // The block_sync_lds() here performs double duty:
+                        // A) safeguard against data hazard because barrier from blockwise_gemm is
+                        //    moved here B) reduce VMEM FIFO congestion by applying small delays to
+                        //    different wavefronts It is performed near the end of MAC cluster to
+                        //    minimize lgkmcnt penalty
+                        if constexpr(k.value == KPerThread - KPerInnerLoop &&
+                                     k_.value == KPerInnerLoop - KPack &&
+                                     m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            block_sync_lds();
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                        // TODO: insert setprio in more precise manner since we
+                        // could have more than >1 MFMA instructions in single call
+                        xdlops_gemm.template Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
+                        {
+                            __builtin_amdgcn_sched_barrier();
+                            __builtin_amdgcn_s_setprio(1);
+                            __builtin_amdgcn_sched_barrier();
+                        }
+                    });
+                });
+            });
+            __builtin_amdgcn_sched_barrier();
+            __builtin_amdgcn_s_setprio(0);
+            __builtin_amdgcn_sched_barrier();
+        });
+    }
+
+    protected:
+    // A[M0, M1, M2, KPerInnerLoop]
+    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+
+    // B[N0, N1, N2, KPerInnerLoop]
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
+
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
+                                                         decltype(a_block_desc_m0_m1_m2_k),
+                                                         decltype(a_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerInnerLoop>,
+                                                         Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;
+
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
+                                                         decltype(b_block_desc_n0_n1_n2_k),
+                                                         decltype(b_thread_desc_),
+                                                         Sequence<1, 1, 1, KPerInnerLoop>,
+                                                         Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>;
+
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
+};
+
+template <index_t BlockSize, typename FloatAB, typename FloatAcc,
+          typename AK0MK1BlockDesc, typename BK0NK1BlockDesc,
+          index_t MPerXDL, index_t NPerXDL, index_t MRepeat, index_t NRepeat, index_t KPack,
+          LoopScheduler LoopSched>
+constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
+{
+    if constexpr(LoopSched == LoopScheduler::Default)
+    {
+        return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+            BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+            MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+    else if constexpr(LoopSched == LoopScheduler::Interwave)
+    {
+        return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+            BlockSize, FloatAB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc,
+            MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
+    }
+};
+
 } // namespace ck
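The new selector is meant to let a caller pick the blockwise GEMM implementation from a LoopScheduler template argument, defaulting to make_default_loop_scheduler(). The call sites are not part of this diff, so the sketch below uses hypothetical stand-in types (BlockGemmDefault, BlockGemmInterwave, GridwiseGemmSketch) to show only the selection mechanics.

#include <type_traits>

enum struct LoopScheduler { Default, Interwave };                 // as added in this commit
constexpr LoopScheduler make_default_loop_scheduler() { return LoopScheduler::Default; }

struct BlockGemmDefault {};      // stand-in for BlockwiseGemmXdlops_..._v1
struct BlockGemmInterwave {};    // stand-in for BlockwiseGemmXdlopsInterwave_..._v1

template <LoopScheduler LoopSched>
constexpr auto blockwise_gemm_selector()   // mirrors the _Selector() shape from the diff
{
    if constexpr(LoopSched == LoopScheduler::Default)
        return BlockGemmDefault{};
    else
        return BlockGemmInterwave{};
}

template <LoopScheduler LoopSched = make_default_loop_scheduler()>
struct GridwiseGemmSketch            // hypothetical consumer
{
    using BlockwiseGemm = decltype(blockwise_gemm_selector<LoopSched>());
};

static_assert(std::is_same_v<GridwiseGemmSketch<>::BlockwiseGemm, BlockGemmDefault>,
              "with the flag left at 0, the default scheduler picks the original blockwise GEMM");

int main() { return 0; }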