Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
211dae82
Commit
211dae82
authored
Oct 27, 2021
by
ltqin
Browse files
Merge branch 'develop' into miopen_downstream_all
parents
5890e300
d5297aba
Changes
65
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4561 additions
and
509 deletions
+4561
-509
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+0
-280
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...on_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+62
-31
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+106
-46
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
+463
-0
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
+263
-0
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
+463
-0
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
+263
-0
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
+463
-0
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
+291
-0
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
+564
-0
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
+347
-0
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+130
-46
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+209
-0
host/driver_offline/src/conv_bwd_driver_offline.cpp
host/driver_offline/src/conv_bwd_driver_offline.cpp
+69
-44
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+34
-61
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+436
-0
host/driver_offline/src/gemm_driver_offline.cpp
host/driver_offline/src/gemm_driver_offline.cpp
+288
-0
host/host_tensor/include/device.hpp
host/host_tensor/include/device.hpp
+5
-1
host/host_tensor/include/gemm_common.hpp
host/host_tensor/include/gemm_common.hpp
+16
-0
host/host_tensor/include/host_conv_bwd_weight.hpp
host/host_tensor/include/host_conv_bwd_weight.hpp
+89
-0
No files found.
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
5890e300
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template
<
typename
TInWei
,
typename
TAcc
,
typename
TOut
,
typename
InLengths
,
typename
WeiLengths
,
typename
OutLengths
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
Tensor
<
TInWei
>&
in_n_c_hi_wi
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_k_ho_wo_device_buf
(
sizeof
(
TOut
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
in_n_c_hi_wi_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
const
auto
in_n_c_hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
in_n_c_hi_wi_lengths
);
const
auto
wei_k_c_y_x_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_c_y_x_lengths
);
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif
0
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPack
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_KPack
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
2
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_KPack
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#elif 0
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPack
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_KPack
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
2
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_KPack
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPack
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_KPack
=
4
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_KPack
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4]
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPack
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_KPack
=
4
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_KPack
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#endif
const
auto
descs
=
#if 1
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
#else
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
#endif
<
TInWei
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPack
>
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
);
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
#if 0
float ave_time = launch_kernel_gemm_xdlops_v1
#else
float
ave_time
=
launch_kernel_gemm_xdlops_v2
#endif
<
BlockSize
,
TInWei
,
TAcc
,
TOut
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
descs
[
I0
]),
decltype
(
descs
[
I1
]),
decltype
(
descs
[
I2
]),
decltype
(
descs
[
I3
]),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPack
,
MRepeat
,
NRepeat
,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_KPack
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
1
,
0
,
2
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_KPack
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused
// with MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
decltype
(
descs
[
I4
]),
decltype
(
descs
[
I5
]),
decltype
(
descs
[
I6
]),
decltype
(
descs
[
I7
]),
decltype
(
descs
[
I8
])
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
descs
[
I3
],
descs
[
I4
],
descs
[
I5
],
descs
[
I6
],
descs
[
I7
],
descs
[
I8
],
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
out_n_k_ho_wo_device_buf
.
FromDevice
(
out_n_k_ho_wo
.
mData
.
data
());
}
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
View file @
211dae82
...
...
@@ -47,7 +47,35 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const
auto
wei_k_c_y_x_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_c_y_x_lengths
);
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
#if 1
#if 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif
1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
...
...
@@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
out_m0_
m1_m2_n
_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
out_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
...
...
@@ -169,7 +200,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
out_m0_
m1_m2_n
_grid_step_hacks
),
decltype
(
out_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_step_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
...
...
@@ -180,7 +211,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
out_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_step_hacks
,
out_m0_
m1_m2_n
_grid_step_hacks
,
out_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_step_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @
211dae82
...
...
@@ -49,15 +49,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4]
, C = 128,
for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPer
Wave
= 32;
constexpr index_t GemmNPer
Wave
= 32;
constexpr index_t GemmMPer
XDL
= 32;
constexpr index_t GemmNPer
XDL
= 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
...
...
@@ -77,15 +77,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
// [M, N, K0, K1] = [128, 128, 4, 4]
, C = 128,
for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPer
Wave
=
32
;
constexpr
index_t
GemmNPer
Wave
=
32
;
constexpr
index_t
GemmMPer
XDL
=
32
;
constexpr
index_t
GemmNPer
XDL
=
32
;
constexpr
index_t
GemmK1
=
4
;
constexpr
index_t
MRepeat
=
2
;
...
...
@@ -105,15 +105,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
// [M, N, K0, K1] = [256, 256, 4, 8]
, C = 256,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPer
Wave
=
32
;
constexpr
index_t
GemmNPer
Wave
=
32
;
constexpr
index_t
GemmMPer
XDL
=
32
;
constexpr
index_t
GemmNPer
XDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
4
;
...
...
@@ -133,15 +133,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPer
Wave
=
32
;
constexpr
index_t
GemmNPer
Wave
=
32
;
constexpr
index_t
GemmMPer
XDL
=
32
;
constexpr
index_t
GemmNPer
XDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
4
;
...
...
@@ -161,15 +161,15 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
// [M, N, K0, K1] = [128, 256, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPer
Wave
=
32
;
constexpr
index_t
GemmNPer
Wave
=
32
;
constexpr
index_t
GemmMPer
XDL
=
32
;
constexpr
index_t
GemmNPer
XDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
...
...
@@ -188,16 +188,16 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif
1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
#elif
0
// [M, N, K0, K1] = [128, 128, 4, 8]
, C = 64,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPer
Wave
=
32
;
constexpr
index_t
GemmNPer
Wave
=
32
;
constexpr
index_t
GemmMPer
XDL
=
32
;
constexpr
index_t
GemmNPer
XDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
...
...
@@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
...
...
@@ -249,23 +305,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
out_m0_
m1_m2_n
_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M
Repeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N
Repeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M
Waves
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N
Waves
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M
0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M
1
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M
2
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N
1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M
Repeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N
Repeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M
Waves
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N
Waves
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M
0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M
1
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M
2
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N
1
constexpr
auto
out_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M
1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N
1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M
2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M
3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M
4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N
2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M
1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N
1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M
2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M
3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M
4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N
2
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
...
...
@@ -287,8 +343,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPer
Wave
,
GemmNPer
Wave
,
GemmMPer
XDL
,
GemmNPer
XDL
,
GemmK1
,
MRepeat
,
NRepeat
,
...
...
@@ -313,19 +369,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
GemmCThreadTransferDstScalarPerVector
,
decltype
(
in_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
out_m0_
m1_m2_n
_grid_step_hacks
),
decltype
(
out_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
,
out_m0_
m1_m2_n
_grid_step_hacks
,
out_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
...
...
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_km_kn_mn
(
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_k_m
.
mDesc
.
GetLengths
()[
0
];
const
auto
M
=
a_k_m
.
mDesc
.
GetLengths
()[
1
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
a_k_m
.
mDesc
.
GetStrides
()[
1
],
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
BBlockTransferSrcScalarPerVector_N
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
7
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_km_kn_nm
(
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_k_m
.
mDesc
.
GetLengths
()[
0
];
const
auto
M
=
a_k_m
.
mDesc
.
GetLengths
()[
1
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
a_k_m
.
mDesc
.
GetStrides
()[
1
],
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
BBlockTransferSrcScalarPerVector_N
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_km_nk_mn
(
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_k_m
.
mDesc
.
GetLengths
()[
0
];
const
auto
M
=
a_k_m
.
mDesc
.
GetLengths
()[
1
];
const
auto
N
=
b_n_k
.
mDesc
.
GetLengths
()[
0
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
a_k_m
.
mDesc
.
GetStrides
()[
1
],
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
7
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_km_nk_nm
(
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_k_m
.
mDesc
.
GetLengths
()[
0
];
const
auto
M
=
a_k_m
.
mDesc
.
GetLengths
()[
1
];
const
auto
N
=
b_n_k
.
mDesc
.
GetLengths
()[
0
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
a_k_m
.
mDesc
.
GetStrides
()[
1
],
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_mk_kn_mn
(
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
BBlockTransferSrcScalarPerVector_N
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
7
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_mk_kn_nm
(
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
BBlockTransferSrcScalarPerVector_N
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_mk_nk_mn
(
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_n_k
.
mDesc
.
GetLengths
()[
0
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
#if 1
// non-padded GEMM
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
#else
// padded GEMM
const
auto
a_k0_m_k1_grid_desc_tmp
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
MRightPad
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
)
*
MPerBlock
-
M
;
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k0_m_k1_grid_desc_tmp
,
make_tuple
(
make_pass_through_transform
(
K0
),
make_right_pad_transform
(
M
,
MRightPad
),
make_pass_through_transform
(
K1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc_tmp
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
transform_tensor_descriptor
(
c_m_n_grid_desc_tmp
,
make_tuple
(
make_right_pad_transform
(
M
,
MRightPad
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
#endif
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
7
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
0 → 100644
View file @
211dae82
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_mk_nk_nm
(
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_n_k
.
mDesc
.
GetLengths
()[
0
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
211dae82
#ifndef DRIVER_GEMM_XDLOPS_V2R3
#define DRIVER_GEMM_XDLOPS_V2R3
#ifndef DRIVER_GEMM_XDLOPS_V2R3
_HPP
#define DRIVER_GEMM_XDLOPS_V2R3
_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
...
...
@@ -17,8 +17,8 @@ template <ck::index_t BlockSize,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPer
Wave
,
ck
::
index_t
NPer
Wave
,
ck
::
index_t
MPer
XDL
,
ck
::
index_t
NPer
XDL
,
ck
::
index_t
K1
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
...
...
@@ -46,13 +46,17 @@ template <ck::index_t BlockSize,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
>
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
__host__
float
driver_gemm_xdlops_v2r3
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
AGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
...
...
@@ -79,8 +83,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPer
Wave
,
NPer
Wave
,
MPer
XDL
,
NPer
XDL
,
K1
,
MRepeat
,
NRepeat
,
...
...
@@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
>
;
CAccessOrderMRepeatNRepeat
,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
{
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
...
...
@@ -123,32 +129,44 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
M01
,
N01
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
const
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
using
CM0
M1M2N
GridDesc
=
decltype
(
c_m0_
m1_m2_n
_grid_desc
);
using
CM0
N0M1N1M2M3M4N2
GridDesc
=
decltype
(
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
,
M01
,
N01
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0M1M2NGridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>>
;
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
true
>
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
...
...
@@ -158,21 +176,56 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_
m1_m2_n
_grid_desc
,
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
c_block_cluster_adaptor
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
c_m0_
m1_m2_n
_grid_desc_dev_buf
(
sizeof
(
CM0
M1M2N
GridDesc
));
DeviceMem
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc_dev_buf
(
sizeof
(
CM0
N0M1N1M2M3M4N2
GridDesc
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
a_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_k0_m_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_k0_n_k1_grid_desc
);
c_m0_
m1_m2_n
_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_
m1_m2_n
_grid_desc
);
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
float
ave_time
=
launch_and_time_kernel
(
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
...
...
@@ -183,8 +236,39 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m1_m2_n_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
}
#endif
return
ave_time
;
}
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
0 → 100644
View file @
211dae82
#ifndef DRIVER_GEMM_XDLOPS_V2R4
#define DRIVER_GEMM_XDLOPS_V2R4
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
ck
::
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
ABK0MK1GridDesc
,
typename
BBK0NK1GridDesc
,
typename
CMNGridDesc
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
K1
,
ck
::
index_t
MRepeat
,
ck
::
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
typename
CThreadTransferSrcDstAccessOrder
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
typename
AGridStepHacks
,
typename
BGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
__host__
float
driver_gemm_xdlops_v2r4
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
AGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
<
BlockSize
,
FloatAB
,
FloatAcc
,
FloatC
,
CGlobalMemoryDataOperation
,
ABK0MK1GridDesc
,
BBK0NK1GridDesc
,
CMNGridDesc
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BThreadTransferSrcResetCoordinateAfterRun
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
AGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
{
std
::
cout
<<
"a_b_k0_m_k1_grid_desc{"
<<
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
)
<<
", "
<<
a_b_k0_m_k1_grid_desc
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b_b_k0_n_k1_grid_desc{"
<<
b_b_k0_n_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
b_b_k0_n_k1_grid_desc
.
GetLength
(
I1
)
<<
", "
<<
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
)
<<
", "
<<
b_b_k0_n_k1_grid_desc
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
M01
,
N01
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting"
);
}
const
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
=
GridwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
const
auto
KBatch
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
,
M01
,
N01
,
KBatch
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
,
KBatch
);
{
std
::
cout
<<
"gridSize : "
<<
grid_size
<<
std
::
endl
;
}
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
ABK0MK1GridDesc
>
,
remove_reference_t
<
BBK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>>
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_b_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
ABK0MK1GridDesc
));
DeviceMem
b_b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BBK0NK1GridDesc
));
DeviceMem
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
(
sizeof
(
CM0N0M1N1M2M3M4N2GridDesc
));
DeviceMem
c_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockClusterAdaptor
));
a_b_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_b_k0_m_k1_grid_desc
);
b_b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_b_k0_n_k1_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
float
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_b_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
#endif
return
ave_time
;
}
#endif
host/driver_offline/src/conv_bwd_driver_offline.cpp
View file @
211dae82
...
...
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
...
...
@@ -14,15 +15,16 @@
#include "device_tensor.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp"
#define USE_MODE 1
#define USE_CONV_BWD_V4R1_XDL_NHWC
1
#define USE_CONV_BWD_V4R1_XDL_NHWC
0
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
enum
ConvBackwardDataAlgo
{
V4R1XDLNHWC
,
V4R1R2XDLNHWC
,
V4R1XDLNHWC
,
// 0
V4R1R2XDLNHWC
,
// 1
};
int
main
(
int
argc
,
char
*
argv
[])
...
...
@@ -41,7 +43,7 @@ int main(int argc, char* argv[])
// dynamic mode
if
(
argc
!=
22
)
{
printf
(
"arg1 to
5
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"arg1 to
6
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
\n
"
);
exit
(
1
);
}
...
...
@@ -79,7 +81,7 @@ int main(int argc, char* argv[])
// static mode
if
(
argc
<
7
)
{
printf
(
"arg1 to
5
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"arg1 to
6
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
exit
(
1
);
}
...
...
@@ -90,28 +92,28 @@ int main(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
Hi
=
71
;
constexpr
index_t
Wi
=
71
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
const
index_t
conv_stride_h
=
2
;
const
index_t
conv_stride_w
=
2
;
const
index_t
conv_dilation_h
=
1
;
const
index_t
conv_dilation_w
=
1
;
const
index_t
in_left_pad_h
=
1
;
const
index_t
in_left_pad_w
=
1
;
const
index_t
in_right_pad_h
=
1
;
const
index_t
in_right_pad_w
=
1
;
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
constexpr
auto
N
=
Number
<
128
>
{}
;
constexpr
auto
C
=
Number
<
192
>
{}
;
constexpr
auto
Hi
=
Number
<
71
>
{}
;
constexpr
auto
Wi
=
Number
<
71
>
{}
;
constexpr
auto
K
=
Number
<
256
>
{}
;
constexpr
auto
Y
=
Number
<
3
>
{}
;
constexpr
auto
X
=
Number
<
3
>
{}
;
const
expr
auto
conv_stride_h
=
I
2
;
const
expr
auto
conv_stride_w
=
I
2
;
const
expr
auto
conv_dilation_h
=
I
1
;
const
expr
auto
conv_dilation_w
=
I
1
;
const
expr
auto
in_left_pad_h
=
I
1
;
const
expr
auto
in_left_pad_w
=
I
1
;
const
expr
auto
in_right_pad_h
=
I
1
;
const
expr
auto
in_right_pad_w
=
I
1
;
const
expr
auto
YEff
=
(
Y
-
I
1
)
*
conv_dilation_h
+
I
1
;
const
expr
auto
XEff
=
(
X
-
I
1
)
*
conv_dilation_w
+
I
1
;
const
expr
auto
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
I
1
;
const
expr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I
1
;
#endif
#if 0
...
...
@@ -280,6 +282,27 @@ int main(int argc, char* argv[])
const
auto
tmp
=
f_make_for_device_nhwc
();
if
(
Y
==
1
&&
X
==
1
&&
in_left_pad_h
==
0
&&
in_left_pad_w
==
0
&&
in_right_pad_h
==
0
&&
in_right_pad_w
==
0
)
{
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1
<
in_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in_device
,
wei
,
out
,
nrepeat
);
}
else
{
#if 1
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
<
in_data_t
,
acc_data_t
,
out_data_t
>
(
...
...
@@ -294,6 +317,8 @@ int main(int argc, char* argv[])
wei
,
out
,
nrepeat
);
#endif
}
}
#endif
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
211dae82
...
...
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
...
...
@@ -19,13 +20,13 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW
1
#define USE_CONV_FWD_V4R4R2_NHWC
1
#define USE_
DYNAMIC_
MODE 1
#define USE_CONV_FWD_V4R4_NCHW
0
#define USE_CONV_FWD_V4R4R2_NHWC
0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC
0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC
1
enum
ConvForwardAlgo
{
...
...
@@ -49,11 +50,11 @@ int main(int argc, char* argv[])
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
#if USE_MODE
#if USE_
DYNAMIC_
MODE
// dynamic mode
if
(
argc
!=
22
)
{
printf
(
"arg1 to
5
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"arg1 to
6
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
\n
"
);
exit
(
1
);
}
...
...
@@ -91,7 +92,7 @@ int main(int argc, char* argv[])
// static mode
if
(
argc
<
7
)
{
printf
(
"arg1 to
5
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"arg1 to
6
: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
exit
(
1
);
}
...
...
@@ -102,31 +103,31 @@ int main(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
Hi
=
71
;
constexpr
index_t
Wi
=
71
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
const
index_t
conv_stride_h
=
2
;
const
index_t
conv_stride_w
=
2
;
const
index_t
conv_dilation_h
=
1
;
const
index_t
conv_dilation_w
=
1
;
const
index_t
in_left_pad_h
=
1
;
const
index_t
in_left_pad_w
=
1
;
const
index_t
in_right_pad_h
=
1
;
const
index_t
in_right_pad_w
=
1
;
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
constexpr
auto
N
=
Number
<
128
>
{}
;
constexpr
auto
C
=
Number
<
192
>
{}
;
constexpr
auto
Hi
=
Number
<
71
>
{}
;
constexpr
auto
Wi
=
Number
<
71
>
{}
;
constexpr
auto
K
=
Number
<
256
>
{}
;
constexpr
auto
Y
=
Number
<
3
>
{}
;
constexpr
auto
X
=
Number
<
3
>
{}
;
const
expr
auto
conv_stride_h
=
I
2
;
const
expr
auto
conv_stride_w
=
I
2
;
const
expr
auto
conv_dilation_h
=
I
1
;
const
expr
auto
conv_dilation_w
=
I
1
;
const
expr
auto
in_left_pad_h
=
I
1
;
const
expr
auto
in_left_pad_w
=
I
1
;
const
expr
auto
in_right_pad_h
=
I
1
;
const
expr
auto
in_right_pad_w
=
I
1
;
const
expr
auto
YEff
=
(
Y
-
I
1
)
*
conv_dilation_h
+
I
1
;
const
expr
auto
XEff
=
(
X
-
I
1
)
*
conv_dilation_w
+
I
1
;
const
expr
auto
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
I
1
;
const
expr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I
1
;
#endif
#if
1
#if
0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
...
...
@@ -228,7 +229,6 @@ int main(int argc, char* argv[])
}
auto
f_make_for_device_nchw
=
[
&
]()
{
#if USE_MODE
const
auto
in_lengths_dev
=
make_tuple
(
N
,
C
,
Hi
,
Wi
);
const
auto
wei_lengths_dev
=
make_tuple
(
K
,
C
,
Y
,
X
);
const
auto
out_lengths_dev
=
make_tuple
(
N
,
K
,
Ho
,
Wo
);
...
...
@@ -236,19 +236,6 @@ int main(int argc, char* argv[])
const
auto
conv_dilations_dev
=
make_tuple
(
conv_dilation_h
,
conv_dilation_w
);
const
auto
in_left_pads_dev
=
make_tuple
(
in_left_pad_h
,
in_left_pad_w
);
const
auto
in_right_pads_dev
=
make_tuple
(
in_right_pad_h
,
in_right_pad_w
);
#else
const
auto
in_lengths_dev
=
make_tuple
(
Number
<
N
>
{},
Number
<
C
>
{},
Number
<
Hi
>
{},
Number
<
Wi
>
{});
const
auto
wei_lengths_dev
=
make_tuple
(
Number
<
K
>
{},
Number
<
C
>
{},
Number
<
Y
>
{},
Number
<
X
>
{});
const
auto
out_lengths_dev
=
make_tuple
(
Number
<
N
>
{},
Number
<
K
>
{},
Number
<
Ho
>
{},
Number
<
Wo
>
{});
const
auto
conv_strides_dev
=
make_tuple
(
Number
<
conv_stride_h
>
{},
Number
<
conv_stride_w
>
{});
const
auto
conv_dilations_dev
=
make_tuple
(
Number
<
conv_dilation_h
>
{},
Number
<
conv_dilation_w
>
{});
const
auto
in_left_pads_dev
=
make_tuple
(
Number
<
in_left_pad_h
>
{},
Number
<
in_left_pad_w
>
{});
const
auto
in_right_pads_dev
=
make_tuple
(
Number
<
in_right_pad_h
>
{},
Number
<
in_right_pad_w
>
{});
#endif
return
make_tuple
(
in_lengths_dev
,
wei_lengths_dev
,
...
...
@@ -260,7 +247,6 @@ int main(int argc, char* argv[])
};
auto
f_make_for_device_nhwc
=
[
&
]()
{
#if USE_MODE
const
auto
in_lengths_dev
=
make_tuple
(
N
,
Hi
,
Wi
,
C
);
const
auto
wei_lengths_dev
=
make_tuple
(
K
,
Y
,
X
,
C
);
const
auto
out_lengths_dev
=
make_tuple
(
N
,
Ho
,
Wo
,
K
);
...
...
@@ -268,19 +254,6 @@ int main(int argc, char* argv[])
const
auto
conv_dilations_dev
=
make_tuple
(
conv_dilation_h
,
conv_dilation_w
);
const
auto
in_left_pads_dev
=
make_tuple
(
in_left_pad_h
,
in_left_pad_w
);
const
auto
in_right_pads_dev
=
make_tuple
(
in_right_pad_h
,
in_right_pad_w
);
#else
const
auto
in_lengths_dev
=
make_tuple
(
Number
<
N
>
{},
Number
<
Hi
>
{},
Number
<
Wi
>
{},
Number
<
C
>
{});
const
auto
wei_lengths_dev
=
make_tuple
(
Number
<
K
>
{},
Number
<
Y
>
{},
Number
<
X
>
{},
Number
<
C
>
{});
const
auto
out_lengths_dev
=
make_tuple
(
Number
<
N
>
{},
Number
<
Ho
>
{},
Number
<
Wo
>
{},
Number
<
K
>
{});
const
auto
conv_strides_dev
=
make_tuple
(
Number
<
conv_stride_h
>
{},
Number
<
conv_stride_w
>
{});
const
auto
conv_dilations_dev
=
make_tuple
(
Number
<
conv_dilation_h
>
{},
Number
<
conv_dilation_w
>
{});
const
auto
in_left_pads_dev
=
make_tuple
(
Number
<
in_left_pad_h
>
{},
Number
<
in_left_pad_w
>
{});
const
auto
in_right_pads_dev
=
make_tuple
(
Number
<
in_right_pad_h
>
{},
Number
<
in_right_pad_w
>
{});
#endif
return
make_tuple
(
in_lengths_dev
,
wei_lengths_dev
,
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
0 → 100644
View file @
211dae82
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_weight.hpp"
#include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
enum
ConvBackwardWeightAlgo
{
V4R4R2XDLNCHW
,
// 0
V4R4R4XDLNHWC
,
// 1
V4R4R2XDLATOMICNCHW
,
// 2
V4R4R4XDLATOMICNHWC
,
// 3
V4R4R5XDLATOMICNHWC
,
// 4
};
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
#if USE_DYNAMIC_MODE
// dynamic mode
if
(
argc
!=
23
)
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
\n
"
);
printf
(
"additional: desired_grid_size
\n
"
);
exit
(
1
);
}
const
ConvTensorLayout
layout
=
static_cast
<
ConvTensorLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
ConvBackwardWeightAlgo
algo
=
static_cast
<
ConvBackwardWeightAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
index_t
N
=
std
::
stoi
(
argv
[
7
]);
const
index_t
K
=
std
::
stoi
(
argv
[
8
]);
const
index_t
C
=
std
::
stoi
(
argv
[
9
]);
const
index_t
Y
=
std
::
stoi
(
argv
[
10
]);
const
index_t
X
=
std
::
stoi
(
argv
[
11
]);
const
index_t
Hi
=
std
::
stoi
(
argv
[
12
]);
const
index_t
Wi
=
std
::
stoi
(
argv
[
13
]);
const
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
14
]);
const
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
15
]);
const
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
16
]);
const
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
17
]);
const
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
18
]);
const
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
19
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
const
index_t
desired_grid_size
=
std
::
stoi
(
argv
[
22
]);
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
#else
// static mode
if
(
argc
<
7
)
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
exit
(
1
);
}
const
ConvTensorLayout
layout
=
static_cast
<
ConvTensorLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
ConvBackwardWeightAlgo
algo
=
static_cast
<
ConvBackwardWeightAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
constexpr
auto
N
=
Number
<
128
>
{};
constexpr
auto
C
=
Number
<
128
>
{};
constexpr
auto
Hi
=
Number
<
14
>
{};
constexpr
auto
Wi
=
Number
<
14
>
{};
constexpr
auto
K
=
Number
<
256
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
conv_stride_h
=
I1
;
constexpr
auto
conv_stride_w
=
I1
;
constexpr
auto
conv_dilation_h
=
I1
;
constexpr
auto
conv_dilation_w
=
I1
;
constexpr
auto
in_left_pad_h
=
I1
;
constexpr
auto
in_left_pad_w
=
I1
;
constexpr
auto
in_right_pad_h
=
I1
;
constexpr
auto
in_right_pad_w
=
I1
;
constexpr
auto
YEff
=
(
Y
-
I1
)
*
conv_dilation_h
+
I1
;
constexpr
auto
XEff
=
(
X
-
I1
)
*
conv_dilation_w
+
I1
;
constexpr
auto
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
I1
;
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
#endif
#if 0
using in_data_t = float;
using wei_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif
1
using
in_data_t
=
half_t
;
using
out_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
wei_data_t
=
float
;
#elif 1
using
in_data_t
=
int8_t
;
using
out_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
wei_data_t
=
int8_t
;
#endif
std
::
vector
<
std
::
size_t
>
in_lengths_host
(
4
),
wei_lengths_host
(
4
),
out_lengths_host
(
4
);
if
(
layout
==
ConvTensorLayout
::
NCHW
)
{
in_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
in_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C
);
in_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Hi
);
in_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wi
);
wei_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
wei_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C
);
wei_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Y
);
wei_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
X
);
out_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
out_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
out_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Ho
);
out_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wo
);
}
else
if
(
layout
==
ConvTensorLayout
::
NHWC
)
{
in_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
in_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
Hi
);
in_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Wi
);
in_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
C
);
wei_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
wei_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
Y
);
wei_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
X
);
wei_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
C
);
out_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
out_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
Ho
);
out_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Wo
);
out_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
K
);
}
else
{
std
::
runtime_error
(
"wrong! not implemented"
);
}
Tensor
<
in_data_t
>
in
(
in_lengths_host
);
Tensor
<
wei_data_t
>
wei_device
(
wei_lengths_host
);
Tensor
<
wei_data_t
>
wei_host
(
wei_lengths_host
);
Tensor
<
out_data_t
>
out
(
out_lengths_host
);
std
::
cout
<<
"layout: "
<<
layout
<<
std
::
endl
;
ostream_HostTensorDescriptor
(
in
.
mDesc
,
std
::
cout
<<
"in: "
);
ostream_HostTensorDescriptor
(
wei_host
.
mDesc
,
std
::
cout
<<
"wei: "
);
ostream_HostTensorDescriptor
(
out
.
mDesc
,
std
::
cout
<<
"out: "
);
print_array
(
"InLeftPads"
,
make_tuple
(
in_left_pad_h
,
in_left_pad_w
));
print_array
(
"InRightPads"
,
make_tuple
(
in_right_pad_h
,
in_right_pad_w
));
print_array
(
"ConvStrides"
,
make_tuple
(
conv_stride_h
,
conv_stride_w
));
print_array
(
"ConvDilations"
,
make_tuple
(
conv_dilation_h
,
conv_dilation_w
));
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
// no initialization
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
out
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
2
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
out
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
3
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
out
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
4
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
out
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
5
:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.1
,
0.1
},
num_thread
);
out
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.1
,
0.1
},
num_thread
);
break
;
default:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
1
,
5
},
num_thread
);
auto
gen_out
=
[](
auto
...
is
)
{
return
GeneratorTensor_2
{
1
,
5
}(
is
...)
*
GeneratorTensor_Checkboard
{}(
is
...);
};
out
.
GenerateTensorValue
(
gen_out
,
num_thread
);
}
auto
f_make_for_device_nchw
=
[
&
]()
{
const
auto
in_lengths_dev
=
make_tuple
(
N
,
C
,
Hi
,
Wi
);
const
auto
wei_lengths_dev
=
make_tuple
(
K
,
C
,
Y
,
X
);
const
auto
out_lengths_dev
=
make_tuple
(
N
,
K
,
Ho
,
Wo
);
const
auto
conv_strides_dev
=
make_tuple
(
conv_stride_h
,
conv_stride_w
);
const
auto
conv_dilations_dev
=
make_tuple
(
conv_dilation_h
,
conv_dilation_w
);
const
auto
in_left_pads_dev
=
make_tuple
(
in_left_pad_h
,
in_left_pad_w
);
const
auto
in_right_pads_dev
=
make_tuple
(
in_right_pad_h
,
in_right_pad_w
);
return
make_tuple
(
in_lengths_dev
,
wei_lengths_dev
,
out_lengths_dev
,
conv_strides_dev
,
conv_dilations_dev
,
in_left_pads_dev
,
in_right_pads_dev
);
};
auto
f_make_for_device_nhwc
=
[
&
]()
{
const
auto
in_lengths_dev
=
make_tuple
(
N
,
Hi
,
Wi
,
C
);
const
auto
wei_lengths_dev
=
make_tuple
(
K
,
Y
,
X
,
C
);
const
auto
out_lengths_dev
=
make_tuple
(
N
,
Ho
,
Wo
,
K
);
const
auto
conv_strides_dev
=
make_tuple
(
conv_stride_h
,
conv_stride_w
);
const
auto
conv_dilations_dev
=
make_tuple
(
conv_dilation_h
,
conv_dilation_w
);
const
auto
in_left_pads_dev
=
make_tuple
(
in_left_pad_h
,
in_left_pad_w
);
const
auto
in_right_pads_dev
=
make_tuple
(
in_right_pad_h
,
in_right_pad_w
);
return
make_tuple
(
in_lengths_dev
,
wei_lengths_dev
,
out_lengths_dev
,
conv_strides_dev
,
conv_dilations_dev
,
in_left_pads_dev
,
in_right_pads_dev
);
};
// set zero to wei_device
wei_device
.
GenerateTensorValue
(
GeneratorTensor_0
{},
num_thread
);
#if USE_CONV_WRW_V4R4R2_XDL_NCHW
if
(
algo
==
ConvBackwardWeightAlgo
::
V4R4R2XDLNCHW
)
{
if
(
layout
!=
ConvTensorLayout
::
NCHW
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
tmp
=
f_make_for_device_nchw
();
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
<
in_data_t
,
wei_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
nrepeat
);
}
#endif
#if USE_CONV_WRW_V4R4R4_XDL_NHWC
if
(
algo
==
ConvBackwardWeightAlgo
::
V4R4R4XDLNHWC
)
{
if
(
layout
!=
ConvTensorLayout
::
NHWC
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
tmp
=
f_make_for_device_nhwc
();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
<
in_data_t
,
wei_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
nrepeat
);
}
#endif
#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
if
(
algo
==
ConvBackwardWeightAlgo
::
V4R4R2XDLATOMICNCHW
)
{
if
(
layout
!=
ConvTensorLayout
::
NCHW
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
tmp
=
f_make_for_device_nchw
();
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw
<
in_data_t
,
wei_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
desired_grid_size
,
nrepeat
);
}
#endif
#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
if
(
algo
==
ConvBackwardWeightAlgo
::
V4R4R4XDLATOMICNHWC
)
{
if
(
layout
!=
ConvTensorLayout
::
NHWC
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
tmp
=
f_make_for_device_nhwc
();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk
<
in_data_t
,
wei_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
desired_grid_size
,
nrepeat
);
}
#endif
#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC
if
(
algo
==
ConvBackwardWeightAlgo
::
V4R4R5XDLATOMICNHWC
)
{
if
(
layout
!=
ConvTensorLayout
::
NHWC
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
tmp
=
f_make_for_device_nhwc
();
device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk
<
in_data_t
,
wei_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
desired_grid_size
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_direct_convolution_backward_weights
(
out
,
in
,
wei_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
layout
);
check_error
(
wei_host
,
wei_device
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out: "
,
out
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_device: "
,
wei_device
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei_host : "
,
wei_host
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
host/driver_offline/src/gemm_driver_offline.cpp
0 → 100644
View file @
211dae82
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "gemm_common.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdlops_mk_kn_mn.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_mk_kn_nm.hpp"
#include "device_gemm_xdlops_mk_nk_nm.hpp"
#include "device_gemm_xdlops_km_kn_nm.hpp"
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
#define USE_GEMM_XDL_KM_NK_NM 0
enum
GemmAlgo
{
Xdl_MK_KN_MN
,
// 0
Xdl_MK_NK_MN
,
// 1
Xdl_KM_KN_MN
,
// 2
Xdl_KM_NK_MN
,
// 3
Xdl_MK_KN_NM
,
// 4
Xdl_MK_NK_NM
,
// 5
Xdl_KM_KN_NM
,
// 6
Xdl_KM_NK_NM
,
// 7
};
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
if
(
argc
!=
12
)
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: M, N, K
\n
"
);
printf
(
"debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01
\n
"
);
exit
(
1
);
}
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
auto
algo
=
static_cast
<
GemmAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
index_t
M
=
std
::
stoi
(
argv
[
7
]);
const
index_t
N
=
std
::
stoi
(
argv
[
8
]);
const
index_t
K
=
std
::
stoi
(
argv
[
9
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
=
std
::
stoi
(
argv
[
10
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
=
std
::
stoi
(
argv
[
11
]);
#if 0
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
#elif
1
using
ab_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
c_data_t
=
half_t
;
#elif 1
using
ab_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
c_data_t
=
int8_t
;
#endif
std
::
vector
<
std
::
size_t
>
a_lengths_host
(
2
),
b_lengths_host
(
2
),
c_lengths_host
(
2
);
std
::
vector
<
std
::
size_t
>
a_strides_host
(
2
),
b_strides_host
(
2
),
c_strides_host
(
2
);
// A
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
MK_KN_NM
||
layout
==
GemmMatrixLayout
::
MK_NK_NM
)
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
// B
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
KM_NK_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_NM
||
layout
==
GemmMatrixLayout
::
KM_NK_NM
)
{
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
// C
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
||
layout
==
GemmMatrixLayout
::
KM_KN_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
Tensor
<
ab_data_t
>
a
(
a_lengths_host
,
a_strides_host
);
Tensor
<
ab_data_t
>
b
(
b_lengths_host
,
b_strides_host
);
Tensor
<
c_data_t
>
c_host
(
c_lengths_host
,
c_strides_host
);
Tensor
<
c_data_t
>
c_device
(
c_lengths_host
,
c_strides_host
);
std
::
cout
<<
"layout: "
<<
layout
<<
std
::
endl
;
ostream_HostTensorDescriptor
(
a
.
mDesc
,
std
::
cout
<<
"a: "
);
ostream_HostTensorDescriptor
(
b
.
mDesc
,
std
::
cout
<<
"b: "
);
ostream_HostTensorDescriptor
(
c_host
.
mDesc
,
std
::
cout
<<
"c: "
);
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
// no initialization
break
;
case
1
:
a
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
2
:
a
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
3
:
a
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
4
:
a
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
default:
a
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
0.0
,
1.0
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.5
,
0.5
},
num_thread
);
}
#if USE_GEMM_XDL_MK_KN_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_KN_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_NK_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_NK_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_NK_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_KN_MN
if
(
algo
==
GemmAlgo
::
Xdl_KM_KN_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_KN_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_NK_MN
if
(
algo
==
GemmAlgo
::
Xdl_KM_NK_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_NK_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_KN_NM
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_KN_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_kn_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_NK_NM
if
(
algo
==
GemmAlgo
::
Xdl_MK_NK_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_NK_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_nk_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_KN_NM
if
(
algo
==
GemmAlgo
::
Xdl_KM_KN_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_KN_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_kn_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_NK_NM
if
(
algo
==
GemmAlgo
::
Xdl_KM_NK_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_NK_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_nk_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_gemm
(
a
,
b
,
c_host
,
layout
);
check_error
(
c_host
,
c_device
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_host
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_device
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
host/host_tensor/include/device.hpp
View file @
211dae82
...
...
@@ -2,6 +2,9 @@
#define DEVICE_HPP
#include <memory>
#include <functional>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
...
...
@@ -74,7 +77,8 @@ float launch_and_time_kernel(
timer
.
End
();
// std::this_thread::sleep_for (std::chrono::microseconds(10));
return
timer
.
GetElapsedTime
()
/
nrepeat
;
}
#endif
host/host_tensor/include/gemm_common.hpp
0 → 100644
View file @
211dae82
#ifndef GEMM_COMMON_HPP
#define GEMM_COMMON_HPP
enum
GemmMatrixLayout
{
MK_KN_MN
,
// 0
MK_NK_MN
,
// 1
KM_KN_MN
,
// 2
KM_NK_MN
,
// 3
MK_KN_NM
,
// 4
MK_NK_NM
,
// 5
KM_KN_NM
,
// 6
KM_NK_NM
,
// 7
};
#endif
host/host_tensor/include/host_conv_bwd_weight.hpp
0 → 100644
View file @
211dae82
#pragma once
#include "host_tensor.hpp"
template
<
typename
TOut
,
typename
TIn
,
typename
TWei
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
host_direct_convolution_backward_weights
(
const
Tensor
<
TOut
>&
out
,
const
Tensor
<
TIn
>&
in
,
Tensor
<
TWei
>&
wei
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
,
const
ConvTensorLayout
layout
=
ConvTensorLayout
::
NCHW
)
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
double
v
=
0
;
for
(
int
n
=
0
;
n
<
out
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
int
ho
=
0
;
ho
<
out
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
{
int
hi
=
ho
*
conv_strides
[
I0
]
+
y
*
conv_dilations
[
I0
]
-
in_left_pads
[
I0
];
for
(
int
wo
=
0
;
wo
<
out
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
{
int
wi
=
wo
*
conv_strides
[
I1
]
+
x
*
conv_dilations
[
I1
]
-
in_left_pads
[
I1
];
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
{
v
+=
static_cast
<
const
double
>
(
in
(
n
,
c
,
hi
,
wi
))
*
static_cast
<
const
double
>
(
out
(
n
,
k
,
ho
,
wo
));
}
}
}
}
wei
(
k
,
c
,
y
,
x
)
=
v
;
};
auto
f_kyxc
=
[
&
](
auto
k
,
auto
y
,
auto
x
,
auto
c
)
{
double
v
=
0
;
for
(
int
n
=
0
;
n
<
out
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
int
ho
=
0
;
ho
<
out
.
mDesc
.
GetLengths
()[
1
];
++
ho
)
{
int
hi
=
ho
*
conv_strides
[
I0
]
+
y
*
conv_dilations
[
I0
]
-
in_left_pads
[
I0
];
for
(
int
wo
=
0
;
wo
<
out
.
mDesc
.
GetLengths
()[
2
];
++
wo
)
{
int
wi
=
wo
*
conv_strides
[
I1
]
+
x
*
conv_dilations
[
I1
]
-
in_left_pads
[
I1
];
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
1
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
2
])
{
v
+=
static_cast
<
const
double
>
(
in
(
n
,
hi
,
wi
,
c
))
*
static_cast
<
const
double
>
(
out
(
n
,
ho
,
wo
,
k
));
}
}
}
}
wei
(
k
,
y
,
x
,
c
)
=
v
;
};
if
(
layout
==
ConvTensorLayout
::
NCHW
)
{
make_ParallelTensorFunctor
(
f_kcyx
,
wei
.
mDesc
.
GetLengths
()[
0
],
wei
.
mDesc
.
GetLengths
()[
1
],
wei
.
mDesc
.
GetLengths
()[
2
],
wei
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
}
else
if
(
layout
==
ConvTensorLayout
::
NHWC
)
{
make_ParallelTensorFunctor
(
f_kyxc
,
wei
.
mDesc
.
GetLengths
()[
0
],
wei
.
mDesc
.
GetLengths
()[
1
],
wei
.
mDesc
.
GetLengths
()[
2
],
wei
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
}
else
{
throw
std
::
runtime_error
(
"wrong! not supported layout"
);
}
}
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment