Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a3b4c5cb
"vscode:/vscode.git/clone" did not exist on "8fd82858c05996686011033237de9182ea1347fb"
Commit
a3b4c5cb
authored
Jun 03, 2022
by
wangshaojie6
Browse files
merge develop branch and add gridwise pipeline v3
parents
48918ab9
1677cf70
Changes
361
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2532 additions
and
206 deletions
+2532
-206
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
...quant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
+82
-60
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+20
-14
example/16_gemm_reduce/CMakeLists.txt
example/16_gemm_reduce/CMakeLists.txt
+2
-1
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
+266
-0
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
...e/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
+308
-0
example/17_convnd_bwd_data_xdl/CMakeLists.txt
example/17_convnd_bwd_data_xdl/CMakeLists.txt
+1
-1
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
+62
-61
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
...e/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
+83
-66
example/19_binary_elementwise/CMakeLists.txt
example/19_binary_elementwise/CMakeLists.txt
+4
-0
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
+165
-0
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
+124
-0
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+145
-0
example/19_binary_elementwise/elementwise_add_4d.cpp
example/19_binary_elementwise/elementwise_add_4d.cpp
+147
-0
example/20_convnd_bwd_weight_xdl/CMakeLists.txt
example/20_convnd_bwd_weight_xdl/CMakeLists.txt
+2
-0
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
+422
-0
example/21_gemm_layernorm/CMakeLists.txt
example/21_gemm_layernorm/CMakeLists.txt
+1
-0
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
+378
-0
example/22_cgemm/CMakeLists.txt
example/22_cgemm/CMakeLists.txt
+1
-0
example/22_cgemm/cgemm_xdl_fp16.cpp
example/22_cgemm/cgemm_xdl_fp16.cpp
+302
-0
example/CMakeLists.txt
example/CMakeLists.txt
+17
-3
No files found.
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
View file @
a3b4c5cb
...
@@ -13,84 +13,106 @@
...
@@ -13,84 +13,106 @@
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
template
<
ck
::
index_t
...
Is
>
struct
RequantReluRequant
using
S
=
ck
::
Sequence
<
Is
...
>
;
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant
(
float
scaleGemm
,
float
scaleRelu
)
:
scaleGemm_
(
scaleGemm
),
scaleRelu_
(
scaleRelu
)
{
}
using
F32
=
float
;
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x
)
const
{
float
gemm_requant
=
scaleGemm_
*
x
;
float
relu
=
gemm_requant
>
0
?
gemm_requant
:
0
;
float
relu_requant
=
scaleRelu_
*
relu
;
y
=
relu_requant
>
127
?
127
:
relu_requant
<
-
128
?
-
128
:
relu_requant
;
}
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
float
scaleGemm_
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
float
scaleRelu_
;
};
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
RequantReluRequant
=
ck
::
tensor_operation
::
element_wise
::
RequantReluRequant
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
int8_t
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
AccDataType
=
int32_t
;
using
CShuffleDataType
=
int32_
t
;
using
CShuffleDataType
=
floa
t
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
<
ADataType
,
// ADataType
ALayout
,
// typename ALayout,
BDataType
,
// BDataType
BLayout
,
// typename BLayout,
CDataType
,
// CDataType
CLayout
,
// typename CLayout,
AccDataType
,
// AccDataType
ADataType
,
// typename ADataType,
CShuffleDataType
,
// CShuffleDataType
BDataType
,
// typename BDataType,
ALayout
,
// ALayout
CDataType
,
// typename CDataType,
BLayout
,
// BLayout
AccDataType
,
// typename GemmAccDataType,
CLayout
,
// CLayout
CShuffleDataType
,
// typename CShuffleDataType,
PassThrough
,
// AElementwiseOperation
PassThrough
,
// typename AElementwiseOperation,
PassThrough
,
// BElementwiseOperation
PassThrough
,
// typename BElementwiseOperation,
RequantReluRequant
,
// CElementwiseOperation
RequantReluRequant
,
// typename CElementwiseOperation,
256
,
// BlockSize
GemmDefault
,
// GemmSpecialization GemmSpec,
256
,
// MPerBlock
1
,
// index_t NumGemmKPrefetchStage,
128
,
// NPerBlock
256
,
// index_t BlockSize,
64
,
// KPerBlock
256
,
// index_t MPerBlock,
16
,
// AK1
128
,
// index_t NPerBlock,
16
,
// BK1
64
,
// index_t KPerBlock,
32
,
// MPerXDL
16
,
// index_t AK1,
32
,
// NPerXDL
16
,
// index_t BK1,
4
,
// MXdlPerWave
32
,
// index_t MPerXDL,
2
,
// NXdlPerWave
32
,
// index_t NPerXDL,
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
4
,
// index_t MXdlPerWave,
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
2
,
// index_t NXdlPerWave,
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
4
,
64
,
1
>
,
// typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
2
,
// ABlockTransferSrcVectorDim
S
<
1
,
0
,
2
>
,
// typename ABlockTransferThreadClusterArrangeOrder,
16
,
// ABlockTransferSrcScalarPerVector
S
<
1
,
0
,
2
>
,
// typename ABlockTransferSrcAccessOrder,
16
,
// ABlockTransferDstScalarPerVector_K1
2
,
// index_t ABlockTransferSrcVectorDim,
true
,
// ABlockLdsAddExtraM
16
,
// index_t ABlockTransferSrcScalarPerVector,
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
16
,
// index_t ABlockTransferDstScalarPerVector_AK1,
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
1
,
// bool ABlockLdsExtraM,
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
4
,
64
,
1
>
,
// typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
2
,
// BBlockTransferSrcVectorDim
S
<
1
,
0
,
2
>
,
// typename BBlockTransferThreadClusterArrangeOrder,
16
,
// BBlockTransferSrcScalarPerVector
S
<
1
,
0
,
2
>
,
// typename BBlockTransferSrcAccessOrder,
16
,
// BBlockTransferDstScalarPerVector_K1
2
,
// index_t BBlockTransferSrcVectorDim,
true
,
// BBlockLdsAddExtraN
8
,
// index_t BBlockTransferSrcScalarPerVector,
1
,
// CShuffleMXdlPerWavePerShuffle
8
,
// index_t BBlockTransferDstScalarPerVector_BK1,
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// bool BBlockLdsExtraN,
S
<
1
,
1
,
64
,
1
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
1
,
// index_t CShuffleMXdlPerWavePerShuffle,
16
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
1
,
// index_t CShuffleNXdlPerWavePerShuffle,
S
<
1
,
64
,
1
,
4
>
,
// typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
16
>
;
// index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
RequantReluRequant
>
;
BDataType
,
CDataType
,
float
,
PassThrough
,
PassThrough
,
RequantReluRequant
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
0
;
bool
do_verification
=
true
;
int
init_method
=
0
;
int
init_method
=
1
;
int
nrepeat
=
5
;
bool
time_kernel
=
false
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
M
=
3840
;
...
@@ -108,13 +130,13 @@ int main(int argc, char* argv[])
...
@@ -108,13 +130,13 @@ int main(int argc, char* argv[])
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
10
)
else
if
(
argc
==
10
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
N
=
std
::
stoi
(
argv
[
5
]);
...
@@ -128,7 +150,7 @@ int main(int argc, char* argv[])
...
@@ -128,7 +150,7 @@ int main(int argc, char* argv[])
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg3:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -202,7 +224,7 @@ int main(int argc, char* argv[])
...
@@ -202,7 +224,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
nrepeat
);
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
}
);
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
@@ -227,7 +249,7 @@ int main(int argc, char* argv[])
...
@@ -227,7 +249,7 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
)
?
0
:
1
;
}
}
return
0
;
return
0
;
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
a3b4c5cb
...
@@ -56,29 +56,29 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
...
@@ -56,29 +56,29 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
0
;
bool
do_verification
=
true
;
int
init_method
=
0
;
int
init_method
=
1
;
int
nrepeat
=
5
;
bool
time_kernel
=
false
;
if
(
argc
==
4
)
if
(
argc
==
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg3:
time
kernel
(0=n0, 1=yes
)
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
int
group_count
=
4
;
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
...
@@ -131,7 +131,7 @@ int main(int argc, char* argv[])
...
@@ -131,7 +131,7 @@ int main(int argc, char* argv[])
std
::
size_t
flop
=
0
,
num_btype
=
0
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
in
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
gemm_shapes
[
i
].
M
,
gemm_shapes
[
i
].
K
,
gemm_shapes
[
i
].
StrideA
,
ALayout
{})));
gemm_shapes
[
i
].
M
,
gemm_shapes
[
i
].
K
,
gemm_shapes
[
i
].
StrideA
,
ALayout
{})));
...
@@ -168,7 +168,7 @@ int main(int argc, char* argv[])
...
@@ -168,7 +168,7 @@ int main(int argc, char* argv[])
}
}
}
}
for
(
in
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
a_tensors_device
.
emplace_back
(
a_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSpace
()));
...
@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
...
@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
// do GEMM
auto
argument
=
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
a_element_op
,
b_element_op
,
c_element_op
);
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
gemm
.
SetWorkSpacePointer
(
&
argument
,
gemm_desc_workspace
.
GetDeviceBuffer
());
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
@@ -202,7 +207,7 @@ int main(int argc, char* argv[])
...
@@ -202,7 +207,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
nrepeat
);
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
}
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -211,9 +216,10 @@ int main(int argc, char* argv[])
...
@@ -211,9 +216,10 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
for
(
in
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
for
(
std
::
size_
t
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
@@ -227,9 +233,9 @@ int main(int argc, char* argv[])
...
@@ -227,9 +233,9 @@ int main(int argc, char* argv[])
c_element_op
);
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_device_tensors
[
i
].
mData
,
c_host_tensors
[
i
].
mData
);
pass
&=
ck
::
utils
::
check_err
(
c_device_tensors
[
i
].
mData
,
c_host_tensors
[
i
].
mData
);
}
}
}
}
return
0
;
return
pass
?
0
:
1
;
}
}
example/16_gemm_reduce/CMakeLists.txt
View file @
a3b4c5cb
add_example_executable
(
example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp
)
add_example_executable
(
example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp
)
example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp
→
example/16_gemm_reduce/gemm_reduce_xdl_
max_
fp16.cpp
View file @
a3b4c5cb
...
@@ -3,7 +3,8 @@
...
@@ -3,7 +3,8 @@
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -20,45 +21,71 @@ using S = ck::Sequence<Is...>;
...
@@ -20,45 +21,71 @@ using S = ck::Sequence<Is...>;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
F64
=
double
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
DDataType
=
F32
;
using
GemmAccDataType
=
F32
;
using
ReduceAccDataType
=
F32
;
using
DDataType
=
F64
;
using
DPtrsGlobal
=
ck
::
Tuple
<
DDataType
*>
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
DsReduceOp
=
ck
::
Tuple
<
ck
::
reduce
::
Max
<
ReduceAccDataType
>>
;
using
D1ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
DsElementOp
=
ck
::
Tuple
<
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>>
;
using
DGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicMax
>
;
static
constexpr
auto
GemmSpecialization
=
static
constexpr
auto
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmReduce_Xdl_CShuffle
using
DeviceGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D
0
| D
1
| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle|
ReduceAcc|
DData| A| B| C| D
xs
|
DxsInEleOp| DxsAccEleOp|
D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce|
Reduce
| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | Type| Type| Type| DataType| DataType|
DataType|
Type
Tuple
| Elementwise| Elementwise| Elementwise|
Reduce|
| | MemoryData
| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | |
|
| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | |
|
| Operation| Operation| Operation|
Operation|
| |
Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | |
|
| | | |
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | |
| |
| | | |
| |
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
D
0
ReduceOp
,
D1Reduce
Op
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
ReduceAccDataType
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
D
s
ReduceOp
,
DsElementOp
,
DsElementOp
,
DGlobalMem
Op
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
BDataType
,
CDataType
,
GemmAccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DDataType
>
void
DumpGemmLayerNormPerf
(
float
gemm_reduce_time
,
int
M
,
int
N
,
int
K
)
{
std
::
size_t
gemm_flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
gemm_num_byte
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
DDataType
)
*
M
;
float
tflops
=
static_cast
<
float
>
(
gemm_flop
)
/
1.E9
/
gemm_reduce_time
;
float
gemm_gb_per_sec
=
gemm_num_byte
/
1.E6
/
gemm_reduce_time
;
std
::
cout
<<
"gemm + reduceMax Perf: "
<<
gemm_reduce_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gemm_gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
1
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
int
nrepeat
=
5
;
bool
time_kernel
=
false
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
M
=
3840
;
...
@@ -77,13 +104,13 @@ int main(int argc, char* argv[])
...
@@ -77,13 +104,13 @@ int main(int argc, char* argv[])
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
10
)
else
if
(
argc
==
10
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
N
=
std
::
stoi
(
argv
[
5
]);
...
@@ -120,22 +147,17 @@ int main(int argc, char* argv[])
...
@@ -120,22 +147,17 @@ int main(int argc, char* argv[])
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_host_result
(
Tensor
<
DDataType
>
d_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_device_result
(
Tensor
<
DDataType
>
d_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_m: "
<<
d0_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_m: "
<<
d_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_m: "
<<
d1_m_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -153,17 +175,16 @@ int main(int argc, char* argv[])
...
@@ -153,17 +175,16 @@ int main(int argc, char* argv[])
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
DDataType
)
*
d0_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
DDataType
)
*
d1_m_device_result
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
d
0_reduce_op
=
D0Reduce
Op
{};
auto
d
s_element_op
=
DsElement
Op
{};
auto
d1_reduce_op
=
D1ReduceOp
{}
;
auto
p_ds_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()))
;
// do GEMM
// do GEMM
auto
gemm
=
DeviceGemmReduceInstance
{};
auto
gemm
=
DeviceGemmReduceInstance
{};
...
@@ -171,8 +192,7 @@ int main(int argc, char* argv[])
...
@@ -171,8 +192,7 @@ int main(int argc, char* argv[])
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
p_ds_global
,
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()),
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -182,8 +202,8 @@ int main(int argc, char* argv[])
...
@@ -182,8 +202,8 @@ int main(int argc, char* argv[])
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
d
0_reduce
_op
,
d
s_element
_op
,
d
1_reduce
_op
);
d
s_element
_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -192,47 +212,17 @@ int main(int argc, char* argv[])
...
@@ -192,47 +212,17 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
// warm up
// [CAUSION]: launch_and_time_kernel will not initialize D.
invoker
.
Run
(
argument
);
// If we evaluate kernel multiple time but without initialize D. Verification will fail
d_device_buf
.
SetValue
(
ck
::
NumericLimits
<
DDataType
>::
Lowest
());
// timing
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
float
total_time
=
0
;
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
// init DO, D1 to 0
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
KernelTimer
timer
;
timer
.
Start
();
invoker
.
Run
(
argument
);
timer
.
End
();
total_time
+=
timer
.
GetElapsedTime
();
}
float
ave_time
=
total_time
/
nrepeat
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
bool
pass
=
true
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
if
(
do_verification
)
if
(
do_verification
)
{
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
d0_device_buf
.
FromDevice
(
d0_m_device_result
.
mData
.
data
());
d_device_buf
.
FromDevice
(
d_m_device_result
.
mData
.
data
());
d1_device_buf
.
FromDevice
(
d1_m_device_result
.
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
@@ -242,25 +232,35 @@ int main(int argc, char* argv[])
...
@@ -242,25 +232,35 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
auto
d_reduce_op
=
DsReduceOp
{}[
ck
::
Number
<
0
>
{}];
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
float
d0_acc
=
d0_reduce_op
.
GetReduceZeroValue
();
ReduceAccDataType
d_acc
=
d_reduce_op
.
GetIdentityValue
();
float
d1_acc
=
d1_reduce_op
.
GetReduceZeroValue
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
d_reduce_op
(
d_acc
,
c_m_n_host_result
(
m
,
n
));
d0_reduce_op
.
Reduce
(
d0_acc
,
c_m_n_host_result
(
m
,
n
));
d1_reduce_op
.
Reduce
(
d1_acc
,
c_m_n_host_result
(
m
,
n
));
}
d0_m_host_result
(
m
)
=
d0_acc
;
d_m_host_result
(
m
)
=
d_acc
;
d1_m_host_result
(
m
)
=
d1_acc
;
}
}
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
pass
=
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
check_error
(
d0_m_host_result
,
d0_m_device_result
);
c_m_n_host_result
.
mData
,
check_error
(
d1_m_host_result
,
d1_m_device_result
);
"Error: Incorrect results c"
)
&&
ck
::
utils
::
check_err
(
d_m_device_result
.
mData
,
d_m_host_result
.
mData
,
"Error: Incorrect results d"
,
1e-3
,
1e-3
);
}
if
(
time_kernel
)
{
float
gemm_reduceMax_ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
true
});
DumpGemmLayerNormPerf
<
ADataType
,
BDataType
,
CDataType
,
DDataType
>
(
gemm_reduceMax_ave_time
,
M
,
N
,
K
);
}
}
return
0
;
return
pass
?
0
:
1
;
}
}
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
0 → 100644
View file @
a3b4c5cb
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "reduction_operator.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
GemmAccDataType
=
F32
;
using
ReduceAccDataType
=
F32
;
using
DDataType
=
F32
;
using
DPtrsGlobal
=
ck
::
Tuple
<
DDataType
*
,
DDataType
*>
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
UnaryDivElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
true
>
;
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOp
=
ck
::
Tuple
<
UnaryDivElementOp
,
UnaryDivElementOp
>
;
using
DGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
,
DxsOutElementOp
,
DGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
GemmAccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DDataType
>
void
DumpGemmLayerNormPerf
(
float
gemm_reduce_time
,
int
M
,
int
N
,
int
K
)
{
std
::
size_t
gemm_flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
gemm_num_byte
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
DDataType
)
*
M
+
sizeof
(
DDataType
)
*
M
;
float
tflops
=
static_cast
<
float
>
(
gemm_flop
)
/
1.E9
/
gemm_reduce_time
;
float
gemm_gb_per_sec
=
gemm_num_byte
/
1.E6
/
gemm_reduce_time
;
std
::
cout
<<
"gemm + reduce_mean + reduce_mean_square Perf: "
<<
gemm_reduce_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gemm_gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
if
(
argc
==
1
)
{
// do nothing
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
10
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideC
=
std
::
stoi
(
argv
[
9
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_host_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d0_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
Tensor
<
DDataType
>
d1_m_device_result
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
)})));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_m: "
<<
d0_m_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_m: "
<<
d1_m_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
DDataType
)
*
d0_m_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_device_buf
(
sizeof
(
DDataType
)
*
d1_m_device_result
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()));
auto
dxs_in_element_op
=
DxsInElementOp
{};
auto
dxs_out_element_op
=
DxsOutElementOp
{
M
,
M
};
// do GEMM
auto
gemm
=
DeviceGemmReduceInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
dxs_global
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs_in_element_op
,
dxs_out_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
// init DO, D1 to 0
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
d0_device_buf
.
FromDevice
(
d0_m_device_result
.
mData
.
data
());
d1_device_buf
.
FromDevice
(
d1_m_device_result
.
mData
.
data
());
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
auto
d0_reduce_op
=
D0ReduceOp
{};
auto
d1_reduce_op
=
D1ReduceOp
{};
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
d0_acc
=
d0_reduce_op
.
GetIdentityValue
();
float
d1_acc
=
d1_reduce_op
.
GetIdentityValue
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
float
c_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
float
d0_val
=
0
;
float
d1_val
=
0
;
dxs_in_element_op
(
ck
::
Number
<
0
>
{})(
d0_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
1
>
{})(
d1_val
,
c_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
dxs_out_element_op
(
ck
::
Number
<
0
>
{})(
d0_acc
,
d0_acc
);
dxs_out_element_op
(
ck
::
Number
<
1
>
{})(
d1_acc
,
d1_acc
);
d0_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
);
d1_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
);
}
pass
=
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
,
"Error: Incorrect results c"
)
&&
ck
::
utils
::
check_err
(
d0_m_device_result
.
mData
,
d0_m_host_result
.
mData
,
"Error: Incorrect results d0"
,
1e-4
,
1e-5
)
&&
ck
::
utils
::
check_err
(
d1_m_device_result
.
mData
,
d1_m_host_result
.
mData
,
"Error: Incorrect results d1"
,
1e-3
,
1e-5
);
}
if
(
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
true
});
DumpGemmLayerNormPerf
<
ADataType
,
BDataType
,
CDataType
,
DDataType
>
(
ave_time
,
M
,
N
,
K
);
}
return
pass
?
0
:
1
;
}
example/17_convnd_bwd_data_xdl/CMakeLists.txt
View file @
a3b4c5cb
add_example_executable
(
example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp
)
add_example_executable
(
example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp
)
target_link_libraries
(
example_convnd_bwd_data_xdl PRIVATE conv_
fwd_
util
)
target_link_libraries
(
example_convnd_bwd_data_xdl PRIVATE conv_util
)
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
View file @
a3b4c5cb
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <half.hpp>
#include <half.hpp>
#include "config.hpp"
#include "config.hpp"
#include "conv_
fwd_
util.hpp"
#include "conv_util.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -87,7 +87,7 @@ void print_use_msg()
...
@@ -87,7 +87,7 @@ void print_use_msg()
{
{
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=random value, 2= init to 1 )
\n
"
<<
"arg2: initialization (0=no init, 1=random value, 2= init to 1 )
\n
"
<<
"arg3:
run
kernel
# of times (>1
)
\n
"
<<
"arg3:
time
kernel
(0=n0, 1=yes
)
\n
"
<<
"arg4: N spatial dimensions (default 2)
\n
"
<<
"arg4: N spatial dimensions (default 2)
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
" N, K, C,
\n
"
<<
" N, K, C,
\n
"
...
@@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
...
@@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
ck
::
utils
::
conv
::
ConvParams
params
;
ck
::
utils
::
conv
::
ConvParams
params
;
int
arg_idx
=
5
;
int
arg_idx
=
5
;
params
.
num_dim_spatial
=
num_dim_spatial
;
params
.
num_dim_spatial
_
=
num_dim_spatial
;
params
.
N
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
N
_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K
_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C
_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths
.
resize
(
num_dim_spatial
);
params
.
filter_spatial_lengths
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
filter_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
input_spatial_lengths
.
resize
(
num_dim_spatial
);
params
.
input_spatial_lengths
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
input_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
input_spatial_lengths
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
conv_filter_strides
.
resize
(
num_dim_spatial
);
params
.
conv_filter_strides
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
conv_filter_strides
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
conv_filter_strides
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
conv_filter_dilations
.
resize
(
num_dim_spatial
);
params
.
conv_filter_dilations
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
conv_filter_dilations
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
conv_filter_dilations
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
input_left_pads
.
resize
(
num_dim_spatial
);
params
.
input_left_pads
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
input_left_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
input_left_pads
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
params
.
input_right_pads
.
resize
(
num_dim_spatial
);
params
.
input_right_pads
_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
{
params
.
input_right_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
input_right_pads
_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
}
return
params
;
return
params
;
...
@@ -165,25 +165,25 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
...
@@ -165,25 +165,25 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
0
;
bool
do_verification
=
true
;
int
init_method
=
0
;
int
init_method
=
1
;
int
nrepeat
=
5
;
bool
time_kernel
=
false
;
int
num_dim_spatial
=
2
;
int
num_dim_spatial
=
2
;
ck
::
utils
::
conv
::
ConvParams
params
;
ck
::
utils
::
conv
::
ConvParams
params
;
params
.
C
=
128
;
params
.
C
_
=
128
;
if
(
argc
==
4
)
if
(
argc
==
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
>
4
)
else
if
(
argc
>
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
// check args number
// check args number
int
conv_args
=
3
+
num_dim_spatial
*
6
;
int
conv_args
=
3
+
num_dim_spatial
*
6
;
...
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
...
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
std
::
vector
<
std
::
size_t
>
input_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N
),
std
::
vector
<
std
::
size_t
>
input_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N
_
),
static_cast
<
std
::
size_t
>
(
params
.
C
)};
static_cast
<
std
::
size_t
>
(
params
.
C
_
)};
input_dims
.
insert
(
std
::
end
(
input_dims
),
input_dims
.
insert
(
std
::
end
(
input_dims
),
std
::
begin
(
params
.
input_spatial_lengths
),
std
::
begin
(
params
.
input_spatial_lengths
_
),
std
::
end
(
params
.
input_spatial_lengths
));
std
::
end
(
params
.
input_spatial_lengths
_
));
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
params
.
K
),
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
params
.
K
_
),
static_cast
<
std
::
size_t
>
(
params
.
C
)};
static_cast
<
std
::
size_t
>
(
params
.
C
_
)};
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
std
::
begin
(
params
.
filter_spatial_lengths
),
std
::
begin
(
params
.
filter_spatial_lengths
_
),
std
::
end
(
params
.
filter_spatial_lengths
));
std
::
end
(
params
.
filter_spatial_lengths
_
));
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
=
params
.
GetOutputSpatialLengths
();
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
=
params
.
GetOutputSpatialLengths
();
std
::
vector
<
std
::
size_t
>
output_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N
),
std
::
vector
<
std
::
size_t
>
output_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N
_
),
static_cast
<
std
::
size_t
>
(
params
.
K
)};
static_cast
<
std
::
size_t
>
(
params
.
K
_
)};
output_dims
.
insert
(
std
::
end
(
output_dims
),
output_dims
.
insert
(
std
::
end
(
output_dims
),
std
::
begin
(
output_spatial_lengths
),
std
::
begin
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
));
std
::
end
(
output_spatial_lengths
));
...
@@ -263,16 +263,16 @@ int main(int argc, char* argv[])
...
@@ -263,16 +263,16 @@ int main(int argc, char* argv[])
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N
,
params
.
N
_
,
params
.
K
,
params
.
K
_
,
params
.
C
,
params
.
C
_
,
params
.
input_spatial_lengths
,
params
.
input_spatial_lengths
_
,
params
.
filter_spatial_lengths
,
params
.
filter_spatial_lengths
_
,
output_spatial_lengths
,
output_spatial_lengths
,
params
.
conv_filter_strides
,
params
.
conv_filter_strides
_
,
params
.
conv_filter_dilations
,
params
.
conv_filter_dilations
_
,
params
.
input_left_pads
,
params
.
input_left_pads
_
,
params
.
input_right_pads
,
params
.
input_right_pads
_
,
InElementOp
{},
InElementOp
{},
WeiElementOp
{},
WeiElementOp
{},
OutElementOp
{});
OutElementOp
{});
...
@@ -284,16 +284,16 @@ int main(int argc, char* argv[])
...
@@ -284,16 +284,16 @@ int main(int argc, char* argv[])
"not support this Conv problem"
);
"not support this Conv problem"
);
}
}
float
ave_time
=
invoker
->
Run
(
argument
.
get
(),
nrepeat
);
float
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
}
);
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
params
.
N
,
params
.
C
,
params
.
K
,
params
.
filter_spatial_lengths
,
output_spatial_lengths
);
params
.
N
_
,
params
.
C
_
,
params
.
K
_
,
params
.
filter_spatial_lengths
_
,
output_spatial_lengths
);
std
::
size_t
num_btype
=
ck
::
utils
::
conv
::
get_btype
<
InDataType
,
WeiDataType
,
OutDataType
>
(
std
::
size_t
num_btype
=
ck
::
utils
::
conv
::
get_btype
<
InDataType
,
WeiDataType
,
OutDataType
>
(
params
.
N
,
params
.
N
_
,
params
.
C
,
params
.
C
_
,
params
.
K
,
params
.
K
_
,
params
.
input_spatial_lengths
,
params
.
input_spatial_lengths
_
,
params
.
filter_spatial_lengths
,
params
.
filter_spatial_lengths
_
,
output_spatial_lengths
);
output_spatial_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -310,10 +310,10 @@ int main(int argc, char* argv[])
...
@@ -310,10 +310,10 @@ int main(int argc, char* argv[])
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_c_hi_wi_host_result
,
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_c_hi_wi_host_result
,
wei_k_c_y_x
,
wei_k_c_y_x
,
out_n_k_ho_wo
,
out_n_k_ho_wo
,
params
.
conv_filter_strides
,
params
.
conv_filter_strides
_
,
params
.
conv_filter_dilations
,
params
.
conv_filter_dilations
_
,
params
.
input_left_pads
,
params
.
input_left_pads
_
,
params
.
input_right_pads
,
params
.
input_right_pads
_
,
InElementOp
{},
InElementOp
{},
WeiElementOp
{},
WeiElementOp
{},
OutElementOp
{});
OutElementOp
{});
...
@@ -322,29 +322,30 @@ int main(int argc, char* argv[])
...
@@ -322,29 +322,30 @@ int main(int argc, char* argv[])
in_device_buf
.
FromDevice
(
in_n_c_hi_wi_device_result
.
mData
.
data
());
in_device_buf
.
FromDevice
(
in_n_c_hi_wi_device_result
.
mData
.
data
());
check_error
(
in_n_c_hi_wi_host_result
,
in_n_c_hi_wi_device_result
);
return
ck
::
utils
::
check_err
(
in_n_c_hi_wi_device_result
.
mData
,
in_n_c_hi_wi_host_result
.
mData
)
?
0
:
1
;
};
};
switch
(
num_dim_spatial
)
switch
(
num_dim_spatial
)
{
{
case
3
:
{
case
3
:
{
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
3
>
();
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
3
>
();
verify_f
(
ref_conv
);
return
verify_f
(
ref_conv
);
break
;
}
}
case
2
:
{
case
2
:
{
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
2
>
();
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
2
>
();
verify_f
(
ref_conv
);
return
verify_f
(
ref_conv
);
break
;
}
}
case
1
:
{
case
1
:
{
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
1
>
();
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
1
>
();
verify_f
(
ref_conv
);
return
verify_f
(
ref_conv
);
break
;
}
}
default:
{
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
}
}
}
return
0
;
}
}
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
View file @
a3b4c5cb
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -11,9 +12,9 @@
...
@@ -11,9 +12,9 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "reference_batched_gemm.hpp"
#include "reference_batched_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -24,31 +25,45 @@ using F32 = float;
...
@@ -24,31 +25,45 @@ using F32 = float;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
DDataType
=
F32
;
using
ReduceAccDataType
=
F32
;
using
DDataType
=
F32
;
using
DPtrsGlobal
=
ck
::
Tuple
<
DDataType
*
,
DDataType
*>
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
D1ReduceOp
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnaryIdenticElementOp
>
;
using
DGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmSpecialization
=
static
constexpr
auto
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceBatchedGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmReduce_Xdl_CShuffle
using
DeviceBatchedGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D
0
| D
1
| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc|
DData| A| B| C| D
xs
|
DxsInEleOp| DxsAccEleOp|
D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce|
Reduce
| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | Type| Type| Type| DataType| DataType| DataType|
Type
Tuple
| Elementwise| Elementwise| Elementwise|
Reduce|
| | MemoryData
| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation|
Operation|
Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | |
| Operation
| Operation| Operation| Operation|
| |
Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | |
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | |
|
| | | |
| |
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
AElementOp
,
BElementOp
,
CElementOp
,
D
0
ReduceOp
,
D
1Reduce
Op
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
D
xs
ReduceOp
,
D
xsInElementOp
,
DxsOutElementOp
,
DGlobalMem
Op
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
...
@@ -56,18 +71,18 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
...
@@ -56,18 +71,18 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
1
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
int
nrepeat
=
5
;
bool
time_kernel
=
false
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
M
=
2048
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
1920
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
2048
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideA
=
2048
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideB
=
2048
;
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
StrideC
=
1920
;
ck
::
index_t
BatchCount
=
4
;
ck
::
index_t
BatchCount
=
4
;
...
@@ -79,13 +94,13 @@ int main(int argc, char* argv[])
...
@@ -79,13 +94,13 @@ int main(int argc, char* argv[])
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
11
)
else
if
(
argc
==
11
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
N
=
std
::
stoi
(
argv
[
5
]);
...
@@ -95,13 +110,13 @@ int main(int argc, char* argv[])
...
@@ -95,13 +110,13 @@ int main(int argc, char* argv[])
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideC
=
std
::
stoi
(
argv
[
9
]);
StrideC
=
std
::
stoi
(
argv
[
9
]);
BatchCount
=
std
::
stoi
(
argv
[
9
]);
BatchCount
=
std
::
stoi
(
argv
[
10
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3:
run
kernel
# of times (>1
)
\n
"
);
printf
(
"arg3:
time
kernel
(0=n0, 1=yes
)
\n
"
);
printf
(
"arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount
\n
"
);
printf
(
"arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -171,8 +186,8 @@ int main(int argc, char* argv[])
...
@@ -171,8 +186,8 @@ int main(int argc, char* argv[])
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
d
0_reduce_op
=
D0ReduceOp
{};
auto
d
xs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
auto
d1_reduce_op
=
D1ReduceOp
{}
;
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()))
;
// do GEMM
// do GEMM
auto
batched_gemm
=
DeviceBatchedGemmReduceInstance
{};
auto
batched_gemm
=
DeviceBatchedGemmReduceInstance
{};
...
@@ -181,8 +196,7 @@ int main(int argc, char* argv[])
...
@@ -181,8 +196,7 @@ int main(int argc, char* argv[])
batched_gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
batched_gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
dxs_global
,
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()),
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -192,8 +206,8 @@ int main(int argc, char* argv[])
...
@@ -192,8 +206,8 @@ int main(int argc, char* argv[])
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
d0_reduce_op
,
DxsInElementOp
{}
,
d1_reduce_op
,
DxsOutElementOp
{}
,
BatchCount
);
BatchCount
);
if
(
!
batched_gemm
.
IsSupportedArgument
(
argument
))
if
(
!
batched_gemm
.
IsSupportedArgument
(
argument
))
...
@@ -203,30 +217,13 @@ int main(int argc, char* argv[])
...
@@ -203,30 +217,13 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
// warm up
// init DO, D1 to 0
invoker
.
Run
(
argument
);
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
// timing
float
total_time
=
0
;
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
// init DO, D1 to 0
d0_device_buf
.
SetZero
();
d1_device_buf
.
SetZero
();
KernelTimer
timer
;
timer
.
Start
();
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
invoker
.
Run
(
argument
);
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
timer
.
End
();
total_time
+=
timer
.
GetElapsedTime
();
}
float
ave_time
=
total_time
/
nrepeat
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
BatchCount
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
BatchCount
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
BatchCount
*
M
*
K
+
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
BatchCount
*
M
*
K
+
...
@@ -240,6 +237,7 @@ int main(int argc, char* argv[])
...
@@ -240,6 +237,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
batched_gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
batched_gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
...
@@ -254,28 +252,47 @@ int main(int argc, char* argv[])
...
@@ -254,28 +252,47 @@ int main(int argc, char* argv[])
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
auto
d0_reduce_op
=
D0ReduceOp
{};
auto
d1_reduce_op
=
D1ReduceOp
{};
for
(
int
batch
=
0
;
batch
<
BatchCount
;
++
batch
)
for
(
int
batch
=
0
;
batch
<
BatchCount
;
++
batch
)
{
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
float
d0_acc
=
d0_reduce_op
.
Get
ReduceZero
Value
();
float
d0_acc
=
d0_reduce_op
.
Get
Identity
Value
();
float
d1_acc
=
d1_reduce_op
.
Get
ReduceZero
Value
();
float
d1_acc
=
d1_reduce_op
.
Get
Identity
Value
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
d0_reduce_op
.
Reduce
(
d0_acc
,
c_g_m_n_host_result
(
batch
,
m
,
n
));
float
c_val
=
ck
::
type_convert
<
float
>
(
c_g_m_n_host_result
(
batch
,
m
,
n
));
d1_reduce_op
.
Reduce
(
d1_acc
,
c_g_m_n_host_result
(
batch
,
m
,
n
));
float
d0_val
=
0
;
float
d1_val
=
0
;
UnaryIdenticElementOp
{}(
d0_val
,
c_val
);
UnarySquareElementOp
{}(
d1_val
,
c_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
}
d0_g_m_host_result
(
batch
,
m
)
=
d0_acc
;
d0_g_m_host_result
(
batch
,
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
)
;
d1_g_m_host_result
(
batch
,
m
)
=
d1_acc
;
d1_g_m_host_result
(
batch
,
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
)
;
}
}
}
}
check_error
(
c_g_m_n_host_result
,
c_g_m_n_device_result
);
pass
=
ck
::
utils
::
check_err
(
c_g_m_n_host_result
.
mData
,
check_error
(
d0_g_m_host_result
,
d0_g_m_device_result
);
c_g_m_n_device_result
.
mData
,
check_error
(
d1_g_m_host_result
,
d1_g_m_device_result
);
"Error: Incorrect results c"
)
&&
ck
::
utils
::
check_err
(
d0_g_m_device_result
.
mData
,
d0_g_m_host_result
.
mData
,
"Error: Incorrect results! D0"
,
1e-4
,
1e-5
)
&&
ck
::
utils
::
check_err
(
d1_g_m_device_result
.
mData
,
d1_g_m_host_result
.
mData
,
"Error: Incorrect results! D1"
,
1e-3
,
1e-5
);
}
}
return
0
;
return
pass
?
0
:
1
;
}
}
example/19_binary_elementwise/CMakeLists.txt
0 → 100644
View file @
a3b4c5cb
add_example_executable
(
example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp
)
add_example_executable
(
example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp
)
add_example_executable
(
example_elementwise_add_1d elementwise_add_1d.cpp
)
add_example_executable
(
example_elementwise_add_4d elementwise_add_4d.cpp
)
\ No newline at end of file
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
0 → 100644
View file @
a3b4c5cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ABDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
,
8
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
HostTensorC
,
typename
ComputeDataType
,
typename
Functor
,
int
broadcastDim
>
void
host_broadcast2D
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
int
M
,
int
N
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
,
0
))
>
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
Amn
=
ck
::
type_convert
<
ComputeDataType
>
(
A
(
m
,
n
));
ComputeDataType
Cmn
=
0
;
if
constexpr
(
broadcastDim
==
0
)
{
ComputeDataType
Bn
=
ck
::
type_convert
<
ComputeDataType
>
(
B
(
n
));
functor
(
Cmn
,
Amn
,
Bn
);
}
else
{
ComputeDataType
Bm
=
ck
::
type_convert
<
ComputeDataType
>
(
B
(
m
));
functor
(
Cmn
,
Amn
,
Bm
);
}
C
(
m
,
n
)
=
ck
::
type_convert
<
ctype
>
(
Cmn
);
}
}
}
int
main
()
{
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
ck
::
index_t
Stride
=
1024
;
auto
f_host_tensor_descriptor1d
=
[](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
len
}),
std
::
vector
<
std
::
size_t
>
({
stride
}));
};
auto
f_host_tensor_descriptor2d
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
};
Tensor
<
ABDataType
>
a_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
Stride
));
Tensor
<
ABDataType
>
b_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
CDataType
>
c_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
Stride
));
a_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
b_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
DeviceMem
a_m_n_device_buf
(
sizeof
(
ABDataType
)
*
a_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_device_buf
(
sizeof
(
ABDataType
)
*
b_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_m_n_device_buf
.
ToDevice
(
a_m_n
.
mData
.
data
());
b_n_device_buf
.
ToDevice
(
b_n
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_n_device_buf
.
GetDeviceBuffer
(),
b_n_device_buf
.
GetDeviceBuffer
(),
c_m_n_device_buf
.
GetDeviceBuffer
(),
{
M
,
N
},
{
Stride
,
1
},
{
0
,
1
},
// broadcast in first dimension
{
Stride
,
1
},
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
float
ave_time
=
broadcastAdd_invoker_ptr
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
Tensor
<
CDataType
>
host_c_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
Stride
));
host_broadcast2D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
Add
,
0
>
(
host_c_m_n
,
a_m_n
,
b_n
,
M
,
N
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m_n
.
mData
,
host_c_m_n
.
mData
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
}
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
0 → 100644
View file @
a3b4c5cb
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ABDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
3
,
8
,
1
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
HostTensorC
,
typename
ComputeDataType
,
typename
Functor
>
void
host_broadcast3D_am_bmnk
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
const
std
::
vector
<
std
::
size_t
>&
shape
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
,
0
))
>
;
for
(
std
::
size_t
m
=
0
;
m
<
shape
[
0
];
++
m
)
for
(
std
::
size_t
n
=
0
;
n
<
shape
[
1
];
++
n
)
for
(
std
::
size_t
k
=
0
;
k
<
shape
[
2
];
++
k
)
{
ComputeDataType
a_val
=
ck
::
type_convert
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
b_val
=
ck
::
type_convert
<
ComputeDataType
>
(
B
(
m
,
n
,
k
));
ComputeDataType
c_val
=
0
;
functor
(
c_val
,
a_val
,
b_val
);
C
(
m
,
n
,
k
)
=
ck
::
type_convert
<
ctype
>
(
c_val
);
}
}
int
main
()
{
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
std
::
vector
<
std
::
size_t
>
mnk
=
{
4
,
16
,
32
};
ck
::
index_t
M
=
mnk
[
0
];
Tensor
<
ABDataType
>
a_m
({
M
});
Tensor
<
ABDataType
>
b_m_n_k
(
mnk
);
Tensor
<
CDataType
>
c_m_n_k
(
mnk
);
a_m
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
b_m_n_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
DeviceMem
a_m_device_buf
(
sizeof
(
ABDataType
)
*
a_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_m_n_k_device_buf
(
sizeof
(
ABDataType
)
*
b_m_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_k_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_k
.
mDesc
.
GetElementSpace
());
a_m_device_buf
.
ToDevice
(
a_m
.
mData
.
data
());
b_m_n_k_device_buf
.
ToDevice
(
b_m_n_k
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_n_k_device_buf
.
GetDeviceBuffer
(),
c_m_n_k_device_buf
.
GetDeviceBuffer
(),
std
::
vector
<
ck
::
index_t
>
{
mnk
.
begin
(),
mnk
.
end
()},
{
1
,
0
,
0
},
// broadcast A on second and third dimension
std
::
vector
<
ck
::
index_t
>
{
b_m_n_k
.
mDesc
.
GetStrides
().
begin
(),
b_m_n_k
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
c_m_n_k
.
mDesc
.
GetStrides
().
begin
(),
c_m_n_k
.
mDesc
.
GetStrides
().
end
()},
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
float
ave_time
=
broadcastAdd_invoker_ptr
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
c_m_n_k_device_buf
.
FromDevice
(
c_m_n_k
.
mData
.
data
());
Tensor
<
CDataType
>
host_c_m_n_k
(
mnk
);
host_broadcast3D_am_bmnk
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
Add
>
(
host_c_m_n_k
,
a_m
,
b_m_n_k
,
mnk
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m_n_k
.
mData
,
host_c_m_n_k
.
mData
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
}
example/19_binary_elementwise/elementwise_add_1d.cpp
0 → 100644
View file @
a3b4c5cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ABDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
,
8
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
HostTensorC
,
typename
ComputeDataType
,
typename
Functor
>
void
host_elementwise1D
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
int
M
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
))
>
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
ComputeDataType
Am
=
ck
::
type_convert
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
Bm
=
ck
::
type_convert
<
ComputeDataType
>
(
B
(
m
));
ComputeDataType
Cm
=
0
;
functor
(
Cm
,
Am
,
Bm
);
C
(
m
)
=
ck
::
type_convert
<
ctype
>
(
Cm
);
}
}
int
main
()
{
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
ck
::
index_t
M
=
1024
;
auto
f_host_tensor_descriptor1d
=
[](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
len
}),
std
::
vector
<
std
::
size_t
>
({
stride
}));
};
Tensor
<
ABDataType
>
a_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
ABDataType
>
b_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
CDataType
>
c_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
a_m
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
b_m
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
DeviceMem
a_m_device_buf
(
sizeof
(
ABDataType
)
*
a_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_m_device_buf
(
sizeof
(
ABDataType
)
*
b_m
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_device_buf
(
sizeof
(
CDataType
)
*
c_m
.
mDesc
.
GetElementSpace
());
a_m_device_buf
.
ToDevice
(
a_m
.
mData
.
data
());
b_m_device_buf
.
ToDevice
(
b_m
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_device_buf
.
GetDeviceBuffer
(),
c_m_device_buf
.
GetDeviceBuffer
(),
{
M
},
{
1
},
{
1
},
{
1
},
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
float
ave_time
=
broadcastAdd_invoker_ptr
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
c_m_device_buf
.
FromDevice
(
c_m
.
mData
.
data
());
Tensor
<
CDataType
>
host_c_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
host_elementwise1D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
Add
>
(
host_c_m
,
a_m
,
b_m
,
M
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m
.
mData
,
host_c_m
.
mData
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
}
example/19_binary_elementwise/elementwise_add_4d.cpp
0 → 100644
View file @
a3b4c5cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ABDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
,
8
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
HostTensorC
,
typename
ComputeDataType
,
typename
Functor
>
void
host_elementwise4D
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
const
std
::
vector
<
std
::
size_t
>&
shape
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
,
0
,
0
,
0
))
>
;
for
(
std
::
size_t
n
=
0
;
n
<
shape
[
0
];
++
n
)
for
(
std
::
size_t
c
=
0
;
c
<
shape
[
1
];
++
c
)
for
(
std
::
size_t
h
=
0
;
h
<
shape
[
2
];
++
h
)
for
(
std
::
size_t
w
=
0
;
w
<
shape
[
3
];
++
w
)
{
ComputeDataType
a_val
=
ck
::
type_convert
<
ComputeDataType
>
(
A
(
n
,
c
,
h
,
w
));
ComputeDataType
b_val
=
ck
::
type_convert
<
ComputeDataType
>
(
B
(
n
,
c
,
h
,
w
));
ComputeDataType
c_val
=
0
;
functor
(
c_val
,
a_val
,
b_val
);
C
(
n
,
c
,
h
,
w
)
=
ck
::
type_convert
<
ctype
>
(
c_val
);
}
}
int
main
()
{
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
std
::
vector
<
std
::
size_t
>
nchw
=
{
4
,
16
,
32
,
32
};
Tensor
<
ABDataType
>
a
(
nchw
);
Tensor
<
ABDataType
>
b
(
nchw
);
Tensor
<
CDataType
>
c
(
nchw
);
a
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
b
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
DeviceMem
a_device_buf
(
sizeof
(
ABDataType
)
*
a
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
ABDataType
)
*
b
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
std
::
vector
<
ck
::
index_t
>
{
nchw
.
begin
(),
nchw
.
end
()},
std
::
vector
<
ck
::
index_t
>
{
a
.
mDesc
.
GetStrides
().
begin
(),
a
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
b
.
mDesc
.
GetStrides
().
begin
(),
b
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
c
.
mDesc
.
GetStrides
().
begin
(),
c
.
mDesc
.
GetStrides
().
end
()},
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
float
ave_time
=
broadcastAdd_invoker_ptr
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c
.
mData
.
data
());
Tensor
<
CDataType
>
host_c
(
nchw
);
host_elementwise4D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
Add
>
(
host_c
,
a
,
b
,
nchw
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c
.
mData
,
host_c
.
mData
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
}
example/20_convnd_bwd_weight_xdl/CMakeLists.txt
0 → 100644
View file @
a3b4c5cb
add_example_executable
(
example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp
)
target_link_libraries
(
example_convnd_bwd_weight_xdl PRIVATE conv_util
)
\ No newline at end of file
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
0 → 100644
View file @
a3b4c5cb
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "conv_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "element_wise_operation.hpp"
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "reference_conv_backward_weight.hpp"
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvBwdWeightDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardWeightSpecialization
::
Default
;
using
DeviceConvBwdWeightBasePtr
=
ck
::
tensor_operation
::
device
::
DeviceConvBwdWeightPtr
<
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
// clang-format off
template
<
ck
::
index_t
NumDimSpatial
>
using
DeviceConvndBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdWeightDefault
,
// ConvolutionBackwardWeightSpecialization
NumDimSpatial
,
// NumDimSpatial
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template
<
ck
::
index_t
NumDimSpatial
>
using
ReferenceConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
NumDimSpatial
>
;
void
print_use_msg
()
{
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=random value, 2= init to 1 )
\n
"
<<
"arg3: time kernel (0=n0, 1=yes)
\n
"
<<
"arg4: is show log (0=no, 1=yes)
\n
"
<<
"arg5: split-k
\n
"
<<
"arg6: N spatial dimensions (default 2)
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
" N, K, C,
\n
"
<<
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
<<
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
<<
" <strides>, (ie Sy, Sx for 2D)
\n
"
<<
" <dilations>, (ie Dy, Dx for 2D)
\n
"
<<
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
<<
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
<<
std
::
endl
;
}
ck
::
utils
::
conv
::
ConvParams
parse_conv_params
(
int
num_dim_spatial
,
char
*
argv
[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck
::
utils
::
conv
::
ConvParams
params
;
int
arg_idx
=
7
;
params
.
num_dim_spatial_
=
num_dim_spatial
;
params
.
N_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C_
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
filter_spatial_lengths_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_spatial_lengths_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_spatial_lengths_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_strides_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_strides_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_dilations_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_dilations_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_left_pads_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_left_pads_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_right_pads_
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_right_pads_
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
return
params
;
}
DeviceConvBwdWeightBasePtr
get_conv_instance
(
int
num_dim_spatial
)
{
switch
(
num_dim_spatial
)
{
case
3
:
{
return
std
::
make_unique
<
DeviceConvndBwdWeightInstance
<
3
>>
();
}
case
2
:
{
return
std
::
make_unique
<
DeviceConvndBwdWeightInstance
<
2
>>
();
}
case
1
:
{
return
std
::
make_unique
<
DeviceConvndBwdWeightInstance
<
1
>>
();
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
int
num_dim_spatial
=
2
;
int
do_log
=
0
;
int
split_k
=
1
;
ck
::
utils
::
conv
::
ConvParams
params
;
params
.
C_
=
128
;
if
(
argc
==
6
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
split_k
=
std
::
stoi
(
argv
[
5
]);
}
else
if
(
argc
>
6
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
do_log
=
std
::
stoi
(
argv
[
4
]);
split_k
=
std
::
stoi
(
argv
[
5
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
6
]);
// check args number
int
conv_args
=
3
+
num_dim_spatial
*
6
;
int
cmdline_nargs
=
conv_args
+
7
;
if
(
cmdline_nargs
!=
argc
)
{
print_use_msg
();
exit
(
1
);
}
params
=
parse_conv_params
(
num_dim_spatial
,
argv
);
}
else
if
(
argc
!=
1
)
{
print_use_msg
();
exit
(
1
);
}
std
::
vector
<
std
::
size_t
>
input_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N_
),
static_cast
<
std
::
size_t
>
(
params
.
C_
)};
input_dims
.
insert
(
std
::
end
(
input_dims
),
std
::
begin
(
params
.
input_spatial_lengths_
),
std
::
end
(
params
.
input_spatial_lengths_
));
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
params
.
K_
),
static_cast
<
std
::
size_t
>
(
params
.
C_
)};
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
std
::
begin
(
params
.
filter_spatial_lengths_
),
std
::
end
(
params
.
filter_spatial_lengths_
));
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
=
params
.
GetOutputSpatialLengths
();
std
::
vector
<
std
::
size_t
>
output_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N_
),
static_cast
<
std
::
size_t
>
(
params
.
K_
)};
output_dims
.
insert
(
std
::
end
(
output_dims
),
std
::
begin
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
));
Tensor
<
InDataType
>
in_n_c_hi_wi
(
ck
::
utils
::
conv
::
get_input_host_tensor_descriptor
(
input_dims
,
num_dim_spatial
));
Tensor
<
WeiDataType
>
wei_k_c_y_x_host_result
(
ck
::
utils
::
conv
::
get_filters_host_tensor_descriptor
(
filter_dims
,
num_dim_spatial
));
Tensor
<
WeiDataType
>
wei_k_c_y_x_device_result
(
ck
::
utils
::
conv
::
get_filters_host_tensor_descriptor
(
filter_dims
,
num_dim_spatial
));
Tensor
<
OutDataType
>
out_n_k_ho_wo
(
ck
::
utils
::
conv
::
get_output_host_tensor_descriptor
(
output_dims
,
num_dim_spatial
));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
2
,
2
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
2
,
2
});
break
;
default:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{
1
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
// reset input to zero
wei_device_buf
.
SetZero
();
// do GEMM
auto
conv
=
get_conv_instance
(
num_dim_spatial
);
auto
invoker
=
conv
->
MakeInvokerPointer
();
auto
argument
=
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N_
,
params
.
K_
,
params
.
C_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{},
split_k
);
// alloc work space
size_t
bwd_weight_workspace_size
=
conv
->
GetWorkSpaceSize
(
argument
.
get
());
float
ave_time
=
0.
f
;
if
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
&&
split_k
>
1
)
{
DeviceMem
wei_work_space_device_buf
(
bwd_weight_workspace_size
);
wei_work_space_device_buf
.
SetZero
();
argument
=
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
AccDataType
*>
(
wei_work_space_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N_
,
params
.
K_
,
params
.
C_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{},
split_k
);
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
}
else
{
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
}
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
params
.
N_
,
params
.
C_
,
params
.
K_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
);
std
::
size_t
num_btype
=
ck
::
utils
::
conv
::
get_btype
<
InDataType
,
WeiDataType
,
OutDataType
>
(
params
.
N_
,
params
.
C_
,
params
.
K_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
if
(
do_verification
)
{
auto
verify_f
=
[
&
](
const
auto
&
ref_conv
)
{
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_c_hi_wi
,
wei_k_c_y_x_host_result
,
out_n_k_ho_wo
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
ref_invoker
.
Run
(
ref_argument
);
wei_device_buf
.
FromDevice
(
wei_k_c_y_x_device_result
.
mData
.
data
());
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out: "
,
out_n_k_ho_wo
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in_n_c_hi_wi
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_device(after): "
,
wei_k_c_y_x_device_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_host : "
,
wei_k_c_y_x_host_result
.
mData
,
","
)
<<
std
::
endl
;
}
return
ck
::
utils
::
check_err
(
wei_k_c_y_x_device_result
.
mData
,
wei_k_c_y_x_host_result
.
mData
)
?
0
:
1
;
};
switch
(
num_dim_spatial
)
{
case
3
:
{
auto
ref_conv
=
ReferenceConvBwdWeightInstance
<
3
>
();
return
verify_f
(
ref_conv
);
}
case
2
:
{
auto
ref_conv
=
ReferenceConvBwdWeightInstance
<
2
>
();
return
verify_f
(
ref_conv
);
}
case
1
:
{
auto
ref_conv
=
ReferenceConvBwdWeightInstance
<
1
>
();
return
verify_f
(
ref_conv
);
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
return
0
;
}
example/21_gemm_layernorm/CMakeLists.txt
0 → 100644
View file @
a3b4c5cb
add_example_executable
(
example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp
)
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
0 → 100644
View file @
a3b4c5cb
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_5ary_elementwise.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
GemmAccDataType
=
F32
;
using
ReduceAccDataType
=
F32
;
using
DDataType
=
F32
;
using
DPtrsGlobal
=
ck
::
Tuple
<
DDataType
*
,
DDataType
*>
;
using
GammaDataType
=
F16
;
using
BetaDataType
=
F16
;
using
LayerNormOutDataType
=
F16
;
using
NormalizeComputeDataType
=
F32
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSumOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
ReduceSumOp
,
ReduceSumOp
>
;
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
UnaryDivElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
true
>
;
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOp
=
ck
::
Tuple
<
UnaryDivElementOp
,
UnaryDivElementOp
>
;
using
DxsGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmSpecialization
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
,
DxsOutElementOp
,
DxsGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
GemmAccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
using
NormalizeFunctor
=
ck
::
tensor_operation
::
element_wise
::
Normalize
;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using
DeviceNormalizeInstance
=
ck
::
tensor_operation
::
device
::
Device5AryElementwise
<
CDataType
,
DDataType
,
DDataType
,
GammaDataType
,
BetaDataType
,
LayerNormOutDataType
,
NormalizeComputeDataType
,
NormalizeFunctor
,
2
,
8
,
8
,
// scalarPerVector: gemm_out
1
,
// scalarPerVector: reduce_mean
1
,
// scalarPerVector: reduce_mean_square
8
,
// scalarPerVector: Gamma
8
,
// scalarPerVector: Beta
8
>
;
// scalarPerVector: LayerNorm_out
auto
f_host_tensor_descriptor1d
=
[](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
len
}),
std
::
vector
<
std
::
size_t
>
({
stride
}));
};
auto
f_host_tensor_descriptor2d
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
template
<
typename
CDataType
,
typename
DDataType
,
typename
A_functor
,
typename
B_functor
,
typename
C_functor
>
void
host_gemm_layernorm
(
Tensor
<
LayerNormOutDataType
>&
out_m_n
,
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
ADataType
>&
b_k_n
,
const
Tensor
<
GammaDataType
>&
gamma_n
,
const
Tensor
<
GammaDataType
>&
beta_n
,
A_functor
a_element_op
,
B_functor
b_element_op
,
C_functor
c_element_op
,
int
M
,
int
N
)
{
using
out_type
=
ck
::
remove_reference_t
<
decltype
(
out_m_n
(
0
,
0
))
>
;
int
StrideC
=
N
;
Tensor
<
CDataType
>
c_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
mean_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
DDataType
>
meanSquare_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
auto
averageOpInst
=
UnaryDivElementOp
{
M
};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
// reduce_mean and reduce_square_mean
auto
reduceSumOpInst
=
ReduceSumOp
{};
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
float
mean_acc
=
reduceSumOpInst
.
GetIdentityValue
();
float
square_mean_acc
=
reduceSumOpInst
.
GetIdentityValue
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ReduceAccDataType
c_val
=
ck
::
type_convert
<
float
>
(
c_m_n
(
m
,
n
));
ReduceAccDataType
square_c_val
=
0
;
UnarySquareElementOp
{}(
square_c_val
,
c_val
);
reduceSumOpInst
(
mean_acc
,
c_val
);
reduceSumOpInst
(
square_mean_acc
,
square_c_val
);
}
averageOpInst
(
mean_acc
,
mean_acc
);
averageOpInst
(
square_mean_acc
,
square_mean_acc
);
mean_m
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
mean_acc
);
meanSquare_m
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
square_mean_acc
);
}
// LayerNorm
auto
layerNormInst
=
NormalizeFunctor
{};
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
float
out_f32
=
0
;
layerNormInst
(
out_f32
,
c_m_n
(
m
,
n
),
mean_m
(
m
),
meanSquare_m
(
m
),
gamma_n
(
n
),
beta_n
(
n
));
out_m_n
(
m
,
n
)
=
static_cast
<
out_type
>
(
out_f32
);
}
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
NormalizeDataType
>
void
DumpGemmLayerNormPerf
(
float
gemm_reduce_time
,
float
normalize_time
,
int
M
,
int
N
,
int
K
)
{
std
::
size_t
gemm_flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
gemm_num_byte
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
DDataType
)
*
M
+
sizeof
(
DDataType
)
*
M
;
std
::
size_t
normalize_num_btye
=
sizeof
(
CDataType
)
*
M
*
N
+
sizeof
(
DDataType
)
*
M
+
sizeof
(
DDataType
)
*
M
+
sizeof
(
GammaDataType
)
*
N
+
sizeof
(
BetaDataType
)
*
N
+
sizeof
(
NormalizeDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
gemm_flop
)
/
1.E9
/
gemm_reduce_time
;
float
gemm_gb_per_sec
=
gemm_num_byte
/
1.E6
/
gemm_reduce_time
;
float
normalize_gb_per_sec
=
normalize_num_btye
/
1.E6
/
normalize_time
;
std
::
cout
<<
"gemm + reduce_mean + reduce_square_mean Perf: "
<<
gemm_reduce_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gemm_gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
std
::
cout
<<
"5-ary elementwise Perf: "
<<
normalize_time
<<
" ms, "
<<
normalize_gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
}
int
main
()
{
// GEMM shape
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
ck
::
index_t
K
=
1024
;
ck
::
index_t
StrideA
=
1024
;
ck
::
index_t
StrideB
=
1024
;
ck
::
index_t
StrideC
=
1024
;
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor2d
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor2d
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
reduceMean_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
DDataType
>
reduceMeanSquare_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
GammaDataType
>
gamma_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
BetaDataType
>
beta_n
(
f_host_tensor_descriptor1d
(
N
,
1
));
Tensor
<
LayerNormOutDataType
>
layerNorm_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideC
,
CLayout
{}));
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
-
1
,
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
1
,
1
});
gamma_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
GammaDataType
>
{
-
1
,
1
});
beta_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BetaDataType
>
{
-
1
,
1
});
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
reduceMean_device_buf
(
sizeof
(
DDataType
)
*
reduceMean_m
.
mDesc
.
GetElementSpace
());
DeviceMem
reduceMeanSquare_device_buf
(
sizeof
(
DDataType
)
*
reduceMeanSquare_m
.
mDesc
.
GetElementSpace
());
DeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
gamma_n
.
mDesc
.
GetElementSpace
());
DeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
beta_n
.
mDesc
.
GetElementSpace
());
DeviceMem
layerNorm_device_buf
(
sizeof
(
LayerNormOutDataType
)
*
layerNorm_m_n
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
gamma_device_buf
.
ToDevice
(
gamma_n
.
mData
.
data
());
beta_device_buf
.
ToDevice
(
beta_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
reduceMean_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
reduceMeanSquare_device_buf
.
GetDeviceBuffer
()));
auto
dxs_in_element_op
=
DxsInElementOp
{};
auto
dxs_out_element_op
=
DxsOutElementOp
{
M
,
M
};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto
gemmReduce
=
DeviceGemmReduceInstance
{};
auto
gemmReduce_invoker
=
gemmReduce
.
MakeInvoker
();
auto
gemmReduce_argument
=
gemmReduce
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
dxs_global
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
dxs_in_element_op
,
dxs_out_element_op
);
if
(
!
gemmReduce
.
IsSupportedArgument
(
gemmReduce_argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
reduceMean_device_buf
.
SetZero
();
reduceMeanSquare_device_buf
.
SetZero
();
// Prepare LayerNorm
auto
normalize
=
DeviceNormalizeInstance
{};
auto
normalize_invoker
=
normalize
.
MakeInvoker
();
auto
normalize_argument
=
normalize
.
MakeArgument
(
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
reduceMean_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
reduceMeanSquare_device_buf
.
GetDeviceBuffer
()),
static_cast
<
GammaDataType
*>
(
gamma_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BetaDataType
*>
(
beta_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LayerNormOutDataType
*>
(
layerNorm_device_buf
.
GetDeviceBuffer
()),
{
M
,
N
},
{
StrideC
,
1
},
{
1
,
0
},
{
1
,
0
},
{
0
,
1
},
{
0
,
1
},
{
StrideC
,
1
},
NormalizeFunctor
{});
if
(
!
normalize
.
IsSupportedArgument
(
normalize_argument
))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"Device5AryElementwise instance, exiting!"
);
}
// run kernel
gemmReduce_invoker
.
Run
(
gemmReduce_argument
,
StreamConfig
{
nullptr
,
false
});
normalize_invoker
.
Run
(
normalize_argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
{
// verification
Tensor
<
LayerNormOutDataType
>
host_layerNorm_m_n
(
f_host_tensor_descriptor2d
(
M
,
N
,
StrideC
,
CLayout
{}));
host_gemm_layernorm
<
CDataType
,
DDataType
>
(
host_layerNorm_m_n
,
a_m_k
,
b_k_n
,
gamma_n
,
beta_n
,
a_element_op
,
b_element_op
,
c_element_op
,
M
,
N
);
layerNorm_device_buf
.
FromDevice
(
layerNorm_m_n
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
layerNorm_m_n
.
mData
,
host_layerNorm_m_n
.
mData
,
"Error: Incorrect results d1"
,
1e-3
,
1e-3
);
}
{
// evaluate kernel perf
bool
time_kernel
=
true
;
float
gemm_reduce_mean_reduce_square_mean_ave_time
=
gemmReduce_invoker
.
Run
(
gemmReduce_argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
normalize_ave_time
=
normalize_invoker
.
Run
(
normalize_argument
,
StreamConfig
{
nullptr
,
time_kernel
});
if
(
time_kernel
)
DumpGemmLayerNormPerf
<
ADataType
,
BDataType
,
CDataType
,
DDataType
,
GammaDataType
,
BetaDataType
,
LayerNormOutDataType
>
(
gemm_reduce_mean_reduce_square_mean_ave_time
,
normalize_ave_time
,
M
,
N
,
K
);
}
return
pass
?
0
:
1
;
}
example/22_cgemm/CMakeLists.txt
0 → 100644
View file @
a3b4c5cb
add_example_executable
(
example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp
)
example/22_cgemm/cgemm_xdl_fp16.cpp
0 → 100644
View file @
a3b4c5cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_cgemm.hpp"
#include "gemm_specialization.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
AccDataType
=
F32
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceCGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceCGemm_4Gemm_Xdl_CShuffle
<
ALayout
,
// typename ALayout
BLayout
,
// typename BLayout
CLayout
,
// typename CLayout
ADataType
,
// typename ADataType
BDataType
,
// typename BDataType
CDataType
,
// typename CDataType
AccDataType
,
// typename GemmAccDataType
CDataType
,
// typename CShuffleDataType
PassThrough
,
// typename AElementwiseOperation
PassThrough
,
// typename BElementwiseOperation
PassThrough
,
// typename CElementwiseOperation
GemmDefault
,
// GemmSpecialization GemmSpec
1
,
// index_t NumGemmKPrefetchStage
256
,
// index_t BlockSize
256
,
// index_t MPerBlock
128
,
// index_t NPerBlock
32
,
// index_t KPerBlock
8
,
// index_t AK1
8
,
// index_t BK1
32
,
// index_t MPerXDL
32
,
// index_t NPerXDL
4
,
// index_t MXdlPerWave
2
,
// index_t NXdlPerWave
S
<
4
,
64
,
1
>
,
// typename ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// typename ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// typename ABlockTransferSrcAccessOrder
2
,
// index_t ABlockTransferSrcVectorDim
8
,
// index_t ABlockTransferSrcScalarPerVector
8
,
// index_t ABlockTransferDstScalarPerVector_AK1
1
,
// index_t ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// typename BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// typename BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// typename BBlockTransferSrcAccessOrder
2
,
// index_t BBlockTransferSrcVectorDim
8
,
// index_t BBlockTransferSrcScalarPerVector
8
,
// index_t BBlockTransferDstScalarPerVector_BK1
1
,
// index_t BBlockLdsExtraN
1
,
// index_t CShuffleMXdlPerWavePerShuffle
1
,
// index_t CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
using
ReferenceCGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceCGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// CGEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
10
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideC
=
std
::
stoi
(
argv
[
9
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k_real
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k_imag
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n_real
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n_imag
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_real_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k_real: "
<<
a_m_k_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k_imag: "
<<
a_m_k_imag
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n_real: "
<<
b_k_n_real
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n_imag: "
<<
b_k_n_imag
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_real: "
<<
c_m_n_real_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_imag: "
<<
c_m_n_imag_device_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k_real
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
a_m_k_imag
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b_k_n_real
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
b_k_n_imag
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
2
,
2
});
break
;
default:
a_m_k_real
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
-
0.5
,
0.5
});
a_m_k_imag
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
-
0.5
,
0.5
});
b_k_n_real
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_k_n_imag
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
}
auto
cgemm
=
DeviceCGemmInstance
{};
DeviceMem
a_m_k_real_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_real
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_imag_device_buf
(
sizeof
(
ADataType
)
*
a_m_k_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_real_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_real
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_imag_device_buf
(
sizeof
(
BDataType
)
*
b_k_n_imag
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_real_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_real_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_imag_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_imag_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
workspace_device_buf
(
cgemm
.
GetWorkspaceSize
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
));
a_m_k_real_device_buf
.
ToDevice
(
a_m_k_real
.
mData
.
data
());
a_m_k_imag_device_buf
.
ToDevice
(
a_m_k_imag
.
mData
.
data
());
b_k_n_real_device_buf
.
ToDevice
(
b_k_n_real
.
mData
.
data
());
b_k_n_imag_device_buf
.
ToDevice
(
b_k_n_imag
.
mData
.
data
());
auto
a_element_op
=
PassThrough
{};
auto
b_element_op
=
PassThrough
{};
auto
c_element_op
=
PassThrough
{};
// do GEMM
auto
invoker
=
cgemm
.
MakeInvoker
();
auto
argument
=
cgemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_real_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_imag_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
workspace_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
cgemm
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_cgemm with the specified compilation parameters does "
"not support this CGEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
8
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
(
2
)
*
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
cgemm
.
GetTypeString
()
<<
std
::
endl
;
c_m_n_real_device_buf
.
FromDevice
(
c_m_n_real_device_result
.
mData
.
data
());
c_m_n_imag_device_buf
.
FromDevice
(
c_m_n_imag_device_result
.
mData
.
data
());
if
(
do_verification
)
{
Tensor
<
CDataType
>
c_m_n_real_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_imag_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
auto
ref_cgemm
=
ReferenceCGemmInstance
{};
auto
ref_invoker
=
ref_cgemm
.
MakeInvoker
();
auto
ref_argument
=
ref_cgemm
.
MakeArgument
(
a_m_k_real
,
a_m_k_imag
,
b_k_n_real
,
b_k_n_imag
,
c_m_n_real_host_result
,
c_m_n_imag_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ck
::
utils
::
check_err
(
c_m_n_real_device_result
.
mData
,
c_m_n_real_host_result
.
mData
,
"Verification error: incorrect results in real part!"
,
1e-2
f
,
1e-1
f
);
ck
::
utils
::
check_err
(
c_m_n_imag_device_result
.
mData
,
c_m_n_imag_host_result
.
mData
,
"Verification error: incorrect results in imaginary part!"
,
1e-2
f
,
1e-1
f
);
}
return
0
;
}
example/CMakeLists.txt
View file @
a3b4c5cb
include_directories
(
BEFORE
include_directories
(
BEFORE
${
PROJECT_SOURCE_DIR
}
/include/ck
${
PROJECT_SOURCE_DIR
}
/include/ck
${
PROJECT_SOURCE_DIR
}
/include/ck/utility
${
PROJECT_SOURCE_DIR
}
/include/ck/utility
${
PROJECT_SOURCE_DIR
}
/include/ck/host_utility
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_description
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_description
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor
${
PROJECT_SOURCE_DIR
}
/include/ck/problem_transform
${
PROJECT_SOURCE_DIR
}
/include/ck/problem_transform
...
@@ -19,13 +20,22 @@ include_directories(BEFORE
...
@@ -19,13 +20,22 @@ include_directories(BEFORE
add_custom_target
(
examples
)
add_custom_target
(
examples
)
function
(
add_example_executable EXAMPLE_NAME
)
function
(
add_example_executable EXAMPLE_NAME
FILE_NAME
)
message
(
"adding example
${
EXAMPLE_NAME
}
"
)
message
(
"adding example
${
EXAMPLE_NAME
}
"
)
add_executable
(
${
EXAMPLE_NAME
}
${
ARGN
}
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE host_tensor
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE host_tensor
)
add_test
(
NAME
${
EXAMPLE_NAME
}
COMMAND $<TARGET_FILE:
${
EXAMPLE_NAME
}
>
${
ARGN
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
check
${
EXAMPLE_NAME
}
)
endfunction
(
add_example_executable EXAMPLE_NAME
)
endfunction
(
add_example_executable EXAMPLE_NAME
)
function
(
add_example_executable_no_testing EXAMPLE_NAME FILE_NAME
)
message
(
"adding example
${
EXAMPLE_NAME
}
"
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE host_tensor
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
endfunction
(
add_example_executable_no_testing EXAMPLE_NAME
)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
02_gemm_alpha_beta
)
add_subdirectory
(
02_gemm_alpha_beta
)
add_subdirectory
(
03_gemm_bias_relu
)
add_subdirectory
(
03_gemm_bias_relu
)
...
@@ -38,7 +48,11 @@ add_subdirectory(11_conv2d_bwd_weight)
...
@@ -38,7 +48,11 @@ add_subdirectory(11_conv2d_bwd_weight)
add_subdirectory
(
12_reduce
)
add_subdirectory
(
12_reduce
)
add_subdirectory
(
13_pool2d_fwd
)
add_subdirectory
(
13_pool2d_fwd
)
add_subdirectory
(
14_gemm_xdl_requant_relu_requant
)
add_subdirectory
(
14_gemm_xdl_requant_relu_requant
)
add_subdirectory
(
17_convnd_bwd_data_xdl
)
add_subdirectory
(
15_grouped_gemm
)
add_subdirectory
(
15_grouped_gemm
)
add_subdirectory
(
16_gemm_reduce
)
add_subdirectory
(
16_gemm_reduce
)
add_subdirectory
(
17_convnd_bwd_data_xdl
)
add_subdirectory
(
18_batched_gemm_reduce
)
add_subdirectory
(
18_batched_gemm_reduce
)
add_subdirectory
(
19_binary_elementwise
)
add_subdirectory
(
20_convnd_bwd_weight_xdl
)
add_subdirectory
(
21_gemm_layernorm
)
add_subdirectory
(
22_cgemm
)
Prev
1
2
3
4
5
6
7
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment