gaoqiong / composable_kernel

Commit ed068043, authored Nov 15, 2021 by Jing Zhang

    merged develop

Parents: 41852668, e823d518
Changes: 74 files; showing 20 changed files with 2474 additions and 881 deletions (+2474 −881)
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp  +3 −2
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp  +256 −0
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp  +11 −10
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp  +288 −0
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp  +276 −0
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp  +456 −0
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp  +8 −8
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp  +205 −19
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp  +137 −62
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp  +278 −0
host/driver_offline/src/conv_bwd_driver_offline.cpp  +149 −9
host/driver_offline/src/conv_fwd_driver_offline.cpp  +104 −9
host/driver_offline/src/conv_wrw_driver_offline.cpp  +264 −16
host/host_tensor/include/conv_common.hpp  +0 −9
host/host_tensor/include/device.hpp  +1 −1
host/host_tensor/include/host_conv.hpp  +12 −511
host/host_tensor/include/host_conv_bwd_data.hpp  +0 −135
host/host_tensor/include/host_conv_bwd_weight.hpp  +0 −89
host/host_tensor/include/host_gemm.hpp  +23 −0
host/host_tensor/include/host_tensor.hpp  +3 −1
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp

@@ -3,6 +3,7 @@
 #include "host_tensor.hpp"
 #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"
+#include "debug.hpp"

 template <typename TInWei,
           typename TAcc,
@@ -275,8 +276,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
         wei_gemmk0_gemmm_gemmk1_grid_desc,
         out_gemmk0_gemmn_gemmk1_grid_desc,
         in_gemmm_gemmn_grid_desc,
-        debug_driver_gemm_xdlops_v2r3::M01,
-        debug_driver_gemm_xdlops_v2r3::N01,
+        debug::debug_driver_gemm_xdlops_v2r3::M01,
+        debug::debug_driver_gemm_xdlops_v2r3::N01,
         wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
         out_gemmk0_gemmn_gemmk1_grid_step_hacks,
         in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
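Both hunks serve the same change: the driver's M01/N01 block-mapping knobs are now read through a debug:: namespace supplied by the newly included debug.hpp. That header's contents are not part of this diff; a minimal sketch of what it plausibly provides, assuming M01 and N01 are plain mutable globals defaulting to 1:

// Hypothetical sketch only; debug.hpp itself is not shown in this commit.
// M01/N01 group consecutive blocks along M/N when block IDs are mapped onto
// C-matrix tiles; 1 keeps the plain row-major block-to-tile mapping.
#pragma once

namespace debug {
namespace debug_driver_gemm_xdlops_v2r3 {

static int M01 = 1; // assumed defaults; the real header may differ
static int N01 = 1;

} // namespace debug_driver_gemm_xdlops_v2r3
} // namespace debug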
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp  (new file, mode 100644)

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"

template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_c_hi_wi,
    Tensor<TWei>& wei_k_c_y_x,
    const Tensor<TOut>& out_n_k_ho_wo,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc  = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;

    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1   = Sequence<1, 1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    const auto N  = in_n_c_hi_wi_desc.GetLength(I0);
    const auto C  = in_n_c_hi_wi_desc.GetLength(I1);
    const auto K  = out_n_k_ho_wo_desc.GetLength(I1);
    const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);
    const auto Y  = wei_k_c_y_x_desc.GetLength(I2);
    const auto X  = wei_k_c_y_x_desc.GetLength(I3);

    const auto GemmM      = K;
    const auto GemmN      = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;

    const auto GridMN        = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    const index_t GemmK0 =
        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0
              << " GemmKPad: " << GemmKPad << std::endl;

    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
            wei_k_c_y_x_desc,
            in_n_c_hi_wi_desc,
            out_n_k_ho_wo_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmB
                              Sequence<0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                              Sequence<0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmB
                              Sequence<0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                              Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};

    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum_t::AtomicAdd,
                                decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerWave,
                                GemmNPerWave,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
                                Sequence<0, 2, 1, 3>,
                                Sequence<0, 2, 1, 3>,
                                3,
                                GemmABlockTransferSrcScalarPerVector_GemmK1,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
                                Sequence<0, 2, 1, 3>,
                                Sequence<0, 2, 1, 3>,
                                3,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
                                7,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false,
                                true,
                                true>;

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                               static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                               out_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // reset the weight buffer before the final accumulation run
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());

    driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                       static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                       out_gemmk0_gemmm_gemmk1_grid_desc,
                       in_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
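The grid-sizing arithmetic above is the core of the atomic variant: the reduction length GemmKTotal = N * Ho * Wo is split into GemmKBatch chunks so that enough workgroups are launched to fill the GPU, and each chunk's K extent is rounded up to a multiple of GemmKPerBlock * GemmK1 and padded. A self-contained model of that computation (plain C++, all sizes hypothetical, integer_divide_ceil written out):

#include <algorithm>
#include <cstdio>

// Standalone model of the K-split used by the atomic backward-weight
// kernels in this commit. Sizes below are illustrative, not from the diff.
int main()
{
    const int N = 256, Ho = 14, Wo = 14;          // output batch and spatial dims
    const int GemmM = 256, GemmN = 1152;          // K and Y*X*C
    const int GemmMPerBlock = 128, GemmNPerBlock = 128;
    const int GemmKPerBlock = 4, GemmK1 = 8;
    const int desired_grid_size = 720;            // e.g. #CUs times blocks per CU

    const int GemmKTotal = N * Ho * Wo;
    const int GridMN     = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

    // Launch enough K-batches that GridMN * GemmKBatch approaches the target grid.
    const int GemmKBatch = std::max(desired_grid_size / GridMN, 1);

    // Round per-batch K up to a multiple of GemmKPerBlock * GemmK1
    // (integer_divide_ceil written out as (a + b - 1) / b).
    const int GemmK0 = (GemmKTotal + GemmK1 * GemmKPerBlock * GemmKBatch - 1) /
                       (GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;

    // Padded K actually iterated; the tail past GemmKTotal is handled by the pad transform.
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::printf("GridMN=%d KBatch=%d K0=%d KPad=%d (KTotal=%d)\n",
                GridMN, GemmKBatch, GemmK0, GemmKPad, GemmKTotal);
}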
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp

@@ -4,7 +4,8 @@
 #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"

-template <typename TInWei,
+template <typename TIn,
+          typename TWei,
           typename TAcc,
           typename TOut,
           typename InLengths,
@@ -22,8 +23,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
           const ConvDilations& conv_dilations,
           const InLeftPads& in_left_pads,
           const InRightPads& in_right_pads,
-          const Tensor<TInWei>& in_n_c_hi_wi,
-          Tensor<TInWei>& wei_k_c_y_x,
+          const Tensor<TIn>& in_n_c_hi_wi,
+          Tensor<TWei>& wei_k_c_y_x,
           const Tensor<TOut>& out_n_k_ho_wo,
           ck::index_t nrepeat)
 {
@@ -35,8 +36,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};

-    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
-    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
     DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

     in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
@@ -47,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
     const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
     const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

-#if 1
+#if 0
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
@@ -164,9 +165,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
 {
     float ave_time =
         driver_gemm_xdlops_v2r3<BlockSize,
-                                TInWei,
+                                TIn,
                                 TAcc,
-                                TOut,
+                                TWei,
                                 InMemoryDataOperationEnum_t::Set,
                                 decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                                 decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
@@ -207,8 +208,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
                                 true,  // ABlockLdsExtraM
                                 true   // BBlockLdsExtraN
                                 >(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-                                  static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-                                  static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+                                  static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+                                  static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                                   out_gemmk0_gemmm_gemmk1_grid_desc,
                                   in_gemmk0_gemmn_gemmk1_grid_desc,
                                   wei_gemmm_gemmn_grid_desc,
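The point of this refactor is that the input and weight element types, previously forced equal through the single TInWei parameter, can now diverge; buffer sizing and the kernel-argument casts follow each tensor's own type. A toy illustration of why that matters (types and sizes below are hypothetical, e.g. a 16-bit input with 32-bit weight gradients):

#include <cstddef>
#include <cstdio>

// Toy model of the TInWei -> TIn/TWei split above: once the two element
// types are independent, each device buffer is sized by its own type.
using TIn  = unsigned short; // stand-in for a 16-bit input type
using TWei = float;          // 32-bit weight-gradient accumulation target

int main()
{
    const std::size_t in_elems  = std::size_t(256) * 64 * 14 * 14; // N*C*Hi*Wi (example)
    const std::size_t wei_elems = std::size_t(256) * 64 * 3 * 3;   // K*C*Y*X  (example)

    std::printf("in buffer:  %zu bytes\n", in_elems * sizeof(TIn));
    std::printf("wei buffer: %zu bytes\n", wei_elems * sizeof(TWei));
}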
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp  (new file, mode 100644)

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"

template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [128, 256, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    const auto N  = in_n_hi_wi_c_desc.GetLength(I0);
    const auto C  = in_n_hi_wi_c_desc.GetLength(I3);
    const auto K  = out_n_ho_wo_k_desc.GetLength(I3);
    const auto Ho = out_n_ho_wo_k_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_desc.GetLength(I2);
    const auto Y  = wei_k_y_x_c_desc.GetLength(I1);
    const auto X  = wei_k_y_x_c_desc.GetLength(I2);

    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;

    const auto GridMN        = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);
    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    const index_t GemmK0 =
        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0
              << " GemmKPad: " << GemmKPad << std::endl;

    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
            in_n_hi_wi_c_desc,
            wei_k_y_x_c_desc,
            out_n_ho_wo_k_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{},
            GemmKBatch,
            GemmKPad);

    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc                    = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmN
                              Sequence<0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmN
                              Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};

    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum_t::AtomicAdd,
                                decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerXDL,
                                GemmNPerXDL,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 2, 3>,
                                2,
                                GemmABlockTransferSrcScalarPerVector_GemmM,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 2, 3>,
                                2,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                6,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false, // CAccessOrderMRepeatNRepeat
                                true,
                                true>;

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                               static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        {
            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // reset the weight buffer before the final accumulation run
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());

    driver_gemm_xdlops(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                       static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
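Note the sequence at the end of both atomic variants: the weight buffer is re-uploaded (discarding the partial sums accumulated during the timing loop), the kernel runs once more with nrepeat = 0, and only then is the result copied back. Because every K-batch adds into the same weight tile via AtomicAdd, running again without a reset would double-count. A self-contained CPU model of that split-K accumulation (all sizes hypothetical):

#include <cstdio>
#include <vector>

// CPU model of split-K accumulation as used by the AtomicAdd kernels:
// each K-batch computes a partial product and adds it into the same output
// element. On the GPU the "+=" is an atomic add, so the output must start
// from a clean value or a previous run's partials double-count.
int main()
{
    const int K = 1024, KBatch = 4, KPerBatch = K / KBatch;
    std::vector<float> a(K, 1.0f), b(K, 2.0f);

    float wei = 0.0f; // must be reset before every accumulation pass
    for(int kb = 0; kb < KBatch; ++kb)          // one group of workgroups per batch
        for(int k = kb * KPerBatch; k < (kb + 1) * KPerBatch; ++k)
            wei += a[k] * b[k];                 // AtomicAdd on the GPU

    std::printf("wei = %f (expected %f)\n", wei, 2.0f * K);
}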
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp  (new file, mode 100644)

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"

template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif

    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
        in_n_hi_wi_c_desc,
        wei_k_y_x_c_desc,
        out_n_ho_wo_k_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        Number<GemmK1>{});

    const auto in_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmN
                              Sequence<0, 0, 0, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmN
                              Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TIn,
            TAcc,
            TWei,
            InMemoryDataOperationEnum_t::Set,
            decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(wei_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerXDL,
            GemmNPerXDL,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmABlockTransferSrcScalarPerVector_GemmM,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<0, 2, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false, // CAccessOrderMRepeatNRepeat
            true,
            true>(static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                  static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                  static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                  in_gemmk0_gemmm_gemmk1_grid_desc,
                  out_gemmk0_gemmn_gemmk1_grid_desc,
                  wei_gemmm_gemmn_grid_desc,
                  debug::debug_driver_gemm_xdlops_v2r3::M01,
                  debug::debug_driver_gemm_xdlops_v2r3::N01,
                  in_gemmk0_gemmm_gemmk1_grid_step_hacks,
                  out_gemmk0_gemmn_gemmk1_grid_step_hacks,
                  wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                  in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                  out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                  nrepeat);

        {
            const auto N  = out_n_ho_wo_k_lengths[I0];
            const auto K  = out_n_ho_wo_k_lengths[I3];
            const auto C  = wei_k_y_x_c_lengths[I3];
            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];
            const auto Y  = wei_k_y_x_c_lengths[I1];
            const auto X  = wei_k_y_x_c_lengths[I2];

            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
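The perf figure printed above follows the standard convolution FLOP count, flops = 2 * N * K * Ho * Wo * C * Y * X (one multiply and one add per MAC); dividing by 10^9 and then by the measured milliseconds lands in TFlop/s. A quick self-contained check of the formula (sizes and timing hypothetical):

#include <cstddef>
#include <cstdio>

// Sanity check of the TFlop/s formula used by the drivers above:
// 2*N*K*Ho*Wo*C*Y*X FLOPs, / 1e9, / milliseconds => TFlop/s.
int main()
{
    const std::size_t N = 128, K = 256, Ho = 14, Wo = 14;
    const std::size_t C = 256, Y = 3, X = 3;
    const float ave_time_ms = 1.2f; // hypothetical measured average time

    const double flops  = 2.0 * N * K * Ho * Wo * C * Y * X;
    const double tflops = flops / 1.0e9 / ave_time_ms; // GFLOP per ms == TFlop/s

    std::printf("%.0f FLOPs, %.2f TFlop/s at %.2f ms\n", flops, tflops, ave_time_ms);
}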
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp  (new file, mode 100644)

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"

template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_hi_wi_c,
    Tensor<TWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 64;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 64;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
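Each tuning branch above obeys two consistency rules: the product of a thread-cluster Sequence equals BlockSize, and the elementwise product of slice lengths and cluster lengths reproduces the per-block copy tile [GemmKBatch-per-block, GemmKPerBlock, GemmMPerBlock (or GemmNPerBlock), GemmK1]. A minimal compile-time check for the active [128, 128, 4, 4] branch, written against those invariants only (Sequence here is a local stand-in, not the ck:: template):

#include <cstddef>

// Local stand-in for ck::Sequence, just enough to express the two
// invariants the tuning branches above must satisfy.
template <int... Is>
struct Sequence
{
    static constexpr int product = (Is * ...); // C++17 fold expression
};

using ASlice   = Sequence<1, 1, 4, 2>;  // per-thread copy lengths
using ACluster = Sequence<1, 4, 32, 2>; // threads along each dimension

// Rule 1: one thread per cluster slot, so the cluster product is BlockSize.
static_assert(ACluster::product == 256, "cluster must use exactly BlockSize threads");

// Rule 2: slice * cluster == block copy tile [1, K0=4, M=128, K1=4].
static_assert(1 * 1 == 1 && 1 * 4 == 4 && 4 * 32 == 128 && 2 * 2 == 4,
              "per-thread slice times cluster must tile the block copy");

int main() {}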
const
auto
N
=
in_n_hi_wi_c_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_hi_wi_c_desc
.
GetLength
(
I3
);
const
auto
K
=
out_n_ho_wo_k_desc
.
GetLength
(
I3
);
const
auto
Ho
=
out_n_ho_wo_k_desc
.
GetLength
(
I1
);
const
auto
Wo
=
out_n_ho_wo_k_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_y_x_c_desc
.
GetLength
(
I1
);
const
auto
X
=
wei_k_y_x_c_desc
.
GetLength
(
I2
);
const
auto
GemmM
=
K
;
const
auto
GemmN
=
Y
*
X
*
C
;
const
auto
GemmKTotal
=
N
*
Ho
*
Wo
;
const
auto
GridMN
=
GemmM
*
GemmN
/
(
GemmMPerBlock
*
GemmNPerBlock
);
const
index_t
GemmKBatch
=
std
::
max
(
desired_grid_size
/
GridMN
,
1
);
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1
*
GemmKPerBlock
*
GemmKBatch
)
*
GemmKPerBlock
;
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1
;
std
::
cout
<<
"GemmKTotal: "
<<
GemmKTotal
<<
" GrideSizeMN: "
<<
GridMN
<<
" GemmKBatch: "
<<
GemmKBatch
<<
" GemmK0: "
<<
GemmK0
<<
" gemmKPad: "
<<
GemmKPad
<<
std
::
endl
;
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad
(
in_n_hi_wi_c_desc
,
wei_k_y_x_c_desc
,
out_n_ho_wo_k_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
Number
<
GemmK1
>
{},
GemmKBatch
,
GemmKPad
);
const
auto
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
const
auto
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
const
auto
wei_gemmm_gemmn_grid_desc
=
descs
[
I2
];
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmM
                              Sequence<0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmM
                              Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmKBatch
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1
    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
    constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
    const auto driver_gemm_xdlops =
        driver_gemm_xdlops_v2r4<BlockSize,
                                TIn,
                                TAcc,
                                TWei,
                                InMemoryDataOperationEnum_t::AtomicAdd,
                                decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
                                decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
                                decltype(wei_gemmm_gemmn_grid_desc),
                                GemmMPerBlock,
                                GemmNPerBlock,
                                GemmKPerBlock,
                                GemmMPerXDL,
                                GemmNPerXDL,
                                GemmK1,
                                MRepeat,
                                NRepeat,
                                GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
                                GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 2, 3>,
                                2,
                                GemmABlockTransferSrcScalarPerVector_GemmM,
                                GemmABlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
                                GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
                                Sequence<0, 1, 2, 3>,
                                Sequence<0, 1, 3, 2>,
                                2,
                                GemmBBlockTransferSrcScalarPerVector_GemmN,
                                GemmBBlockTransferDstScalarPerVector_GemmK1,
                                false, // don't move back src coordinate after threadwise copy
                                Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
                                7,
                                GemmCThreadTransferDstScalarPerVector,
                                decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
                                decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
                                decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                                false, // CAccessOrderMRepeatNRepeat
                                true,
                                true>;
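Since the C-side memory operation is InMemoryDataOperationEnum_t::AtomicAdd, each of the GemmKBatch slices adds a partial dW into the same weight buffer. The accumulation in miniature (a host-side sketch under that assumption; the device uses hardware atomic adds, not this serial loop):

    // split-K accumulation: one partial result per K-batch, summed into C
    void accumulate_slices(float* c, const float* partials, int kbatch, int mn)
    {
        for(int b = 0; b < kbatch; ++b)
            for(int i = 0; i < mn; ++i)
                c[i] += partials[b * mn + i]; // AtomicAdd on the GPU, plain += here
    }

This is also why the destination buffer must hold its initial contents before every launch, as the wrw driver later in this commit arranges.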
    // timing
    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                               static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                               out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        {
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }
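The perf line above counts 2*N*K*Ho*Wo*C*Y*X FLOP (one multiply and one add per MAC); dividing by 1e9 and then by a time in milliseconds yields TFlop/s, since FLOP / 1e9 / t_ms = FLOP / (1e12 * t_s). The same formula as a standalone helper (a sketch, not library code):

    #include <cstddef>

    double tflops(std::size_t n, std::size_t k, std::size_t ho, std::size_t wo,
                  std::size_t c, std::size_t y, std::size_t x, float ave_time_ms)
    {
        const double flop = 2.0 * n * k * ho * wo * c * y * x;
        return flop / 1.0e9 / ave_time_ms; // == flop / 1e12 / seconds
    }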
    // verification
    // note: the GEMM writes wei with atomic adds, so restore the initial (host)
    // contents of the weight buffer before the accumulation run
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());

    driver_gemm_xdlops(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                       static_cast<TIn*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                       out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                       in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data());
}
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
View file @ ed068043
...
@@ -141,14 +141,14 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
 #endif

     const auto descs =
-        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
+        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
                                                                       wei_k_y_x_c_desc,
                                                                       out_n_ho_wo_k_desc,
                                                                       conv_strides,
                                                                       conv_dilations,
                                                                       in_left_pads,
                                                                       in_right_pads,
                                                                       Number<GemmK1>{});

     const auto in_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
     const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @ ed068043
...
@@ -4,6 +4,131 @@
 #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
 #include "driver_gemm_xdlops_v2r3.hpp"
#if 0
__host__ __device__ static constexpr auto
MakePaddedGridDescriptors(const AGridDesc_K0Raw_MRaw_K1& a_grid_desc_k0raw_mraw_k1,
const BGridDesc_K0Raw_NRaw_K1& b_grid_desc_k0raw_nraw_k1,
const CGridDesc_MRaw_NRaw& c_grid_desc_mraw_nraw)
{
const auto K0Raw = a_grid_desc_k0raw_mraw_k1.GetLength(I0);
const auto K1 = a_grid_desc_k0raw_mraw_k1.GetLength(I2);
const auto MRaw = c_grid_desc_mraw_nraw.GetLength(I0);
const auto NRaw = c_grid_desc_mraw_nraw.GetLength(I1);
const auto K0Pad = math::integer_least_multiple(K0Raw, K0PerBlock) - K0Raw;
const auto MPad = math::integer_least_multiple(MRaw, MPerBlock) - MRaw;
const auto NPad = math::integer_least_multiple(NRaw, NPerBlock) - NRaw;
// A
const auto a_grid_desc_k0_m_k1 = [&]() {
if constexpr(DoPad_K0 && DoPad_M)
{
return transform_tensor_descriptor(
a_grid_desc_k0raw_mraw_k1,
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else if constexpr(DoPad_K0 && !DoPad_M)
{
return transform_tensor_descriptor(
a_grid_desc_k0raw_mraw_k1,
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
make_pass_through_transform(MRaw),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else if constexpr(!DoPad_K0 && DoPad_M)
{
return transform_tensor_descriptor(
a_grid_desc_k0raw_mraw_k1,
make_tuple(make_pass_through_transform(K0Raw),
make_right_pad_transform(MRaw, MPad),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else
{
return a_grid_desc_k0raw_mraw_k1;
}
}();
// B
const auto b_grid_desc_k0_n_k1 = [&]() {
if constexpr(DoPad_K0 && DoPad_N)
{
return transform_tensor_descriptor(
b_grid_desc_k0raw_nraw_k1,
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
make_right_pad_transform(NRaw, NPad),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else if constexpr(DoPad_K0 && !DoPad_N)
{
return transform_tensor_descriptor(
b_grid_desc_k0raw_nraw_k1,
make_tuple(make_right_pad_transform(K0Raw, K0Pad),
make_pass_through_transform(NRaw),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else if constexpr(!DoPad_K0 && DoPad_N)
{
return transform_tensor_descriptor(
b_grid_desc_k0raw_nraw_k1,
make_tuple(make_pass_through_transform(K0Raw),
make_right_pad_transform(NRaw, NPad),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
else
{
return b_grid_desc_k0raw_nraw_k1;
}
}();
// C
const auto c_grid_desc_m_n = [&]() {
if constexpr(DoPad_M && DoPad_N)
{
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(DoPad_M && !DoPad_N)
{
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(!DoPad_M && DoPad_N)
{
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return c_grid_desc_mraw_nraw;
}
}();
    return make_tuple(a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n);
}
#endif
 template <typename TInWei,
           typename TAcc,
           typename TOut,
...
@@ -160,7 +285,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#elif 0
+#elif 1
     // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize = 256;
...
@@ -188,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#elif 1
+#elif 0
     // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
     constexpr index_t BlockSize = 256;
...
@@ -275,20 +400,19 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
 #endif

     const auto descs =
-        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
+        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
                                                                       wei_k_y_x_c_desc,
                                                                       out_n_ho_wo_k_desc,
                                                                       conv_strides,
                                                                       conv_dilations,
                                                                       in_left_pads,
                                                                       in_right_pads,
                                                                       Number<GemmK1>{});

+#if 0 // debug
     const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
-    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
-    const auto out_gemmm_gemmn_grid_desc         = descs[I2];

-    // HACK: hacks that control index calculation when iterating over A, B, C matrix
+    // HACK: hacks that control index calculation when iterating over A matrix
     constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
...
@@ -297,7 +421,39 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 1-: GemmM
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
-    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
+    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
+#else
    const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = descs[I0];

    const auto GemmK0   = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I0);
    const auto GemmMRaw = in_gemmk0_gemmmraw_gemmk1_grid_desc.GetLength(I1);
    const auto GemmMPad = math::integer_least_multiple(GemmMRaw, GemmMPerBlock) - GemmMRaw;

    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmk0_gemmmraw_gemmk1_grid_desc,
        make_tuple(make_pass_through_transform(GemmK0),
                   make_right_pad_transform(GemmMRaw, GemmMPad),
                   make_pass_through_transform(GemmK1)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

    // HACK: hacks that control index calculation when iterating over A matrix
    constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 1+: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 1-: GemmM
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1

    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
#endif
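The GemmMPad arithmetic above in isolation (a sketch with assumed sizes, not values from this commit):

    // integer_least_multiple(raw, per_block) - raw, spelled out:
    constexpr int pad_up(int raw, int per_block)
    {
        return (raw + per_block - 1) / per_block * per_block - raw;
    }
    static_assert(pad_up(784, 128) == 112, "GemmMRaw = 784 is padded up to 896");

With the pad in place, every MPerBlock-row tile of A and C is full, so no workgroup needs a boundary check in the M dimension.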
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];

    const auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmN
                              Sequence<0, 0, 0, 0, 0>{}),  // 2+: GemmK1
...
@@ -305,6 +461,12 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmN
                              Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

#if 0
    const auto out_gemmm_gemmn_grid_desc = descs[I2];

    constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
...
@@ -322,12 +484,36 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
#else
    const auto out_gemmmraw_gemmn_grid_desc = descs[I2];

    const auto GemmN = out_gemmmraw_gemmn_grid_desc.GetLength(I1);

    const auto out_gemmm_gemmn_grid_desc =
        transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
                                    make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
                                               make_pass_through_transform(GemmN)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
#endif

    for(index_t i = 0; i < 5; ++i)
    {
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @ ed068043
...
@@ -11,8 +11,8 @@ template <ck::index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
+          typename AGridDesc_K0_M_K1,
+          typename BGridDesc_K0_N_K1,
           typename CMNGridDesc,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -52,9 +52,9 @@ template <ck::index_t BlockSize,
 __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
-                                       const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                                       const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                                       const CMNGridDesc& c_m_n_grid_desc,
+                                       const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                                       const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                                       const CMNGridDesc& c_grid_desc_m_n,
                                        ck::index_t M01,
                                        ck::index_t N01,
                                        AGridStepHacks,
@@ -63,7 +63,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks,
                                        ck::index_t nrepeat)
 {
     using namespace ck;
...
@@ -77,8 +76,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
         FloatAcc,
         FloatC,
         CGlobalMemoryDataOperation,
-        AK0MK1GridDesc,
-        BK0NK1GridDesc,
+        AGridDesc_K0_M_K1,
+        BGridDesc_K0_N_K1,
         CMNGridDesc,
         MPerBlock,
         NPerBlock,
...
@@ -117,84 +116,160 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
         BBlockLdsAddExtraN>;

     {
-        std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
-                  << a_k0_m_k1_grid_desc.GetLength(I1) << ", "
-                  << a_k0_m_k1_grid_desc.GetLength(I2) << "}" << std::endl;
+        std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
+                  << a_grid_desc_k0_m_k1.GetLength(I1) << ", "
+                  << a_grid_desc_k0_m_k1.GetLength(I2) << "}" << std::endl;

-        std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
-                  << b_k0_n_k1_grid_desc.GetLength(I1) << ", "
-                  << b_k0_n_k1_grid_desc.GetLength(I2) << "}" << std::endl;
+        std::cout << "b_grid_desc_k0_n_k1{" << b_grid_desc_k0_n_k1.GetLength(I0) << ", "
+                  << b_grid_desc_k0_n_k1.GetLength(I1) << ", "
+                  << b_grid_desc_k0_n_k1.GetLength(I2) << "}" << std::endl;

-        std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
-                  << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
+        std::cout << "c_grid_desc_m_n{ " << c_grid_desc_m_n.GetLength(I0) << ", "
+                  << c_grid_desc_m_n.GetLength(I1) << "}" << std::endl;
     }

     if(!GridwiseGemm::CheckValidity(
-           a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
+           a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, M01, N01))
     {
         throw std::runtime_error(
             "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
     }

     const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
-        GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
+        GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);

-    using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
+    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

-    const auto c_block_cluster_adaptor =
-        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01);
+    const auto block_2_ctile_map = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n, M01, N01);

-    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
+    using Block2CTileMap = decltype(block_2_ctile_map);

-    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
+    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n);

+    const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+
+    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+    float ave_time = 0;
+
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    if(has_main_k0_block_loop)
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K1>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    remove_reference_t<Block2CTileMap>,
                                    true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_k0_m_k1,
                                          b_grid_desc_k0_n_k1,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          block_2_ctile_map);
    }
    else
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K1>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    remove_reference_t<Block2CTileMap>,
                                    false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_k0_m_k1,
                                          b_grid_desc_k0_n_k1,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          block_2_ctile_map);
    }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_grid_desc_k0_m_k1_dev_buf(sizeof(AGridDesc_K0_M_K1));
    DeviceMem b_grid_desc_k0_n_k1_dev_buf(sizeof(BGridDesc_K0_N_K1));
    DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(
        sizeof(CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2));
    DeviceMem block_2_ctile_map_dev_buf(sizeof(Block2CTileMap));

    a_grid_desc_k0_m_k1_dev_buf.ToDevice(&a_grid_desc_k0_m_k1);
    b_grid_desc_k0_n_k1_dev_buf.ToDevice(&b_grid_desc_k0_n_k1);
    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
    block_2_ctile_map_dev_buf.ToDevice(&block_2_ctile_map);

    if(has_main_k0_block_loop)
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K1>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    remove_reference_t<Block2CTileMap>,
                                    true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(a_grid_desc_k0_m_k1_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(b_grid_desc_k0_n_k1_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(block_2_ctile_map_dev_buf.GetDeviceBuffer()));
    }
    else
    {
        const auto kernel =
            kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                    FloatAB,
                                    FloatC,
                                    remove_reference_t<AGridDesc_K0_M_K1>,
                                    remove_reference_t<BGridDesc_K0_N_K1>,
                                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                                    remove_reference_t<Block2CTileMap>,
                                    false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(a_grid_desc_k0_m_k1_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(b_grid_desc_k0_n_k1_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(block_2_ctile_map_dev_buf.GetDeviceBuffer()));
    }
#endif

    return ave_time;
}
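The restructuring above picks a kernel instantiation from a runtime K0, so the main-K0-loop test becomes a compile-time constant inside the kernel. The dispatch pattern in miniature (a sketch; the actual criterion lives in GridwiseGemm::CalculateHasMainK0BlockLoop, which this diff does not show):

    #include <cstdio>

    template <bool HasMainK0BlockLoop>
    void run_gemm() // stand-in for launching kernel_gemm_xdlops_v2r3<..., HasMainK0BlockLoop>
    {
        std::printf("main K0 loop compiled %s\n", HasMainK0BlockLoop ? "in" : "out");
    }

    void dispatch(int k0, int k0_per_block)
    {
        const bool has_main_k0_block_loop = k0 > k0_per_block; // assumed criterion
        has_main_k0_block_loop ? run_gemm<true>() : run_gemm<false>();
    }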
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
0 → 100644
View file @ ed068043
#ifndef DRIVER_GEMM_XDLOPS_V2R4
#define DRIVER_GEMM_XDLOPS_V2R4
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                                       const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                       const CMNGridDesc& c_m_n_grid_desc,
                                       ck::index_t M01,
                                       ck::index_t N01,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    using GridwiseGemm =
        GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<BlockSize,
                                                  FloatAB,
                                                  FloatAcc,
                                                  FloatC,
                                                  CGlobalMemoryDataOperation,
                                                  ABK0MK1GridDesc,
                                                  BBK0NK1GridDesc,
                                                  CMNGridDesc,
                                                  MPerBlock,
                                                  NPerBlock,
                                                  KPerBlock,
                                                  MPerXDL,
                                                  NPerXDL,
                                                  K1,
                                                  MRepeat,
                                                  NRepeat,
                                                  ABlockTransferThreadSliceLengths_K0_M_K1,
                                                  ABlockTransferThreadClusterLengths_K0_M_K1,
                                                  ABlockTransferThreadClusterArrangeOrder,
                                                  ABlockTransferSrcAccessOrder,
                                                  ABlockTransferSrcVectorDim,
                                                  ABlockTransferSrcScalarPerVector,
                                                  ABlockTransferDstScalarPerVector_K1,
                                                  AThreadTransferSrcResetCoordinateAfterRun,
                                                  BBlockTransferThreadSliceLengths_K0_N_K1,
                                                  BBlockTransferThreadClusterLengths_K0_N_K1,
                                                  BBlockTransferThreadClusterArrangeOrder,
                                                  BBlockTransferSrcAccessOrder,
                                                  BBlockTransferSrcVectorDim,
                                                  BBlockTransferSrcScalarPerVector,
                                                  BBlockTransferDstScalarPerVector_K1,
                                                  BThreadTransferSrcResetCoordinateAfterRun,
                                                  CThreadTransferSrcDstAccessOrder,
                                                  CThreadTransferSrcDstVectorDim,
                                                  CThreadTransferDstScalarPerVector,
                                                  AGridStepHacks,
                                                  BGridStepHacks,
                                                  CGridStepHacks,
                                                  AGridMoveSliceWindowStepHacks,
                                                  BGridMoveSliceWindowStepHacks,
                                                  CAccessOrderMRepeatNRepeat,
                                                  ABlockLdsAddExtraM,
                                                  BBlockLdsAddExtraN>;

    {
        std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
                  << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
                  << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
                  << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
                  << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
                  << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
                  << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                  << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
    }
    if(!GridwiseGemm::CheckValidity(
           a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01))
    {
        throw std::runtime_error(
            "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r4 has invalid setting");
    }

    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
        GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);

    using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

    const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

    const auto c_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);

    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);

    {
        std::cout << "gridSize : " << grid_size << std::endl;
    }

    const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

    const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

    float ave_time = 0;
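Unlike v2r3, both MakeCBlockClusterAdaptor and CalculateGridSize take KBatch here: the grid presumably gains one slab of workgroups per K-batch, and the partial products meet in C through atomic adds. Illustrative arithmetic (an assumption about CalculateGridSize, which this diff does not show):

    int grid_size_v2r4(int M, int N, int MPerBlock, int NPerBlock, int KBatch)
    {
        return KBatch * (M / MPerBlock) * (N / NPerBlock); // assumes M, N pre-padded
    }
    // grid_size_v2r4(1024, 1024, 128, 128, 4) == 256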
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    if(has_main_k0_block_loop)
    {
        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                    FloatAB,
                                                    FloatC,
                                                    remove_reference_t<ABK0MK1GridDesc>,
                                                    remove_reference_t<BBK0NK1GridDesc>,
                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
                                                    remove_reference_t<CBlockClusterAdaptor>,
                                                    true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_b_k0_m_k1_grid_desc,
                                          b_b_k0_n_k1_grid_desc,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          c_block_cluster_adaptor);
    }
    else
    {
        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                    FloatAB,
                                                    FloatC,
                                                    remove_reference_t<ABK0MK1GridDesc>,
                                                    remove_reference_t<BBK0NK1GridDesc>,
                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
                                                    remove_reference_t<CBlockClusterAdaptor>,
                                                    false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_b_k0_m_k1_grid_desc,
                                          b_b_k0_n_k1_grid_desc,
                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                          c_block_cluster_adaptor);
    }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
    DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc));
    DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
    DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));

    a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
    b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

    if(has_main_k0_block_loop)
    {
        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                    FloatAB,
                                                    FloatC,
                                                    remove_reference_t<ABK0MK1GridDesc>,
                                                    remove_reference_t<BBK0NK1GridDesc>,
                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
                                                    remove_reference_t<CBlockClusterAdaptor>,
                                                    true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
    }
    else
    {
        const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                    FloatAB,
                                                    FloatC,
                                                    remove_reference_t<ABK0MK1GridDesc>,
                                                    remove_reference_t<BBK0NK1GridDesc>,
                                                    remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
                                                    remove_reference_t<CBlockClusterAdaptor>,
                                                    false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
    }
#endif

    return ave_time;
}
#endif
host/driver_offline/src/conv_bwd_driver_offline.cpp
View file @ ed068043
...
@@ -11,7 +11,6 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "conv_common.hpp"
-#include "host_conv_bwd_data.hpp"
 #include "device_tensor.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
 #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
...
@@ -21,12 +20,153 @@
 #define USE_CONV_BWD_V4R1_XDL_NHWC 0
 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
enum ConvTensorLayout
{
    NCHW,
    NHWC,
    CHWN,
    NCHWc,
    NHWCc
};

enum ConvBackwardDataAlgo
{
    V4R1XDLNHWC,   // 0
    V4R1R2XDLNHWC, // 1
};
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_convolution_backward_data(Tensor<TIn>& in,
                                    const Tensor<TWei>& wei,
                                    const Tensor<TOut>& out,
                                    const ConvStrides& conv_strides,
                                    const ConvDilations& conv_dilations,
                                    const InLeftPads& in_left_pads,
                                    const InRightPads& /* in_right_pads */,
                                    const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
        std::size_t K = wei.mDesc.GetLengths()[I0];
        std::size_t Y = wei.mDesc.GetLengths()[I2];
        std::size_t X = wei.mDesc.GetLengths()[I3];

        std::size_t Ho = out.mDesc.GetLengths()[I2];
        std::size_t Wo = out.mDesc.GetLengths()[I3];

        double v = 0;

        for(int y = 0; y < Y; ++y)
        {
            int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];

            if(h_tmp % conv_strides[I0] == 0)
            {
                int ho = h_tmp / conv_strides[I0];

                if(ho >= 0 && ho < Ho)
                {
                    for(int x = 0; x < X; ++x)
                    {
                        int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];

                        if(w_tmp % conv_strides[I1] == 0)
                        {
                            int wo = w_tmp / conv_strides[I1];

                            if(wo >= 0 && wo < Wo)
                            {
                                for(int k = 0; k < K; ++k)
                                {
                                    v += out(n, k, ho, wo) * wei(k, c, y, x);
                                }
                            }
                        }
                    }
                }
            }
        }

        in(n, c, hi, wi) = v;
    };

    auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
        std::size_t K = wei.mDesc.GetLengths()[I0];
        std::size_t Y = wei.mDesc.GetLengths()[I1];
        std::size_t X = wei.mDesc.GetLengths()[I2];

        std::size_t Ho = out.mDesc.GetLengths()[I1];
        std::size_t Wo = out.mDesc.GetLengths()[I2];

        double v = 0;

        for(int y = 0; y < Y; ++y)
        {
            int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];

            if(h_tmp % conv_strides[I0] == 0)
            {
                int ho = h_tmp / conv_strides[I0];

                if(ho >= 0 && ho < Ho)
                {
                    for(int x = 0; x < X; ++x)
                    {
                        int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];

                        if(w_tmp % conv_strides[I1] == 0)
                        {
                            int wo = w_tmp / conv_strides[I1];

                            if(wo >= 0 && wo < Wo)
                            {
                                for(int k = 0; k < K; ++k)
                                {
                                    v += out(n, ho, wo, k) * wei(k, y, x, c);
                                }
                            }
                        }
                    }
                }
            }
        }

        in(n, hi, wi, c) = v;
    };

    if(layout == ConvTensorLayout::NCHW)
    {
        make_ParallelTensorFunctor(f_nchw,
                                   in.mDesc.GetLengths()[0],
                                   in.mDesc.GetLengths()[1],
                                   in.mDesc.GetLengths()[2],
                                   in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        make_ParallelTensorFunctor(f_nhwc,
                                   in.mDesc.GetLengths()[0],
                                   in.mDesc.GetLengths()[1],
                                   in.mDesc.GetLengths()[2],
                                   in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}
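The stride/modulus test at the heart of both lambdas, restated on its own (a sketch, not library code): input pixel hi receives a contribution from filter tap y only when hi + pad - y*dilation lands exactly on an output row.

    bool tap_contributes(int hi, int pad, int y, int dilation, int stride, int Ho, int& ho)
    {
        const int h_tmp = hi + pad - y * dilation;
        if(h_tmp % stride != 0)
            return false;          // falls between output rows
        ho = h_tmp / stride;
        return ho >= 0 && ho < Ho; // inside the output
    }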
int main(int argc, char* argv[])
{
    using namespace ck;
...
@@ -324,14 +464,14 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_direct_convolution_backward_data(in_host,
-                                              wei,
-                                              out,
+        host_convolution_backward_data(in_host,
+                                       wei,
+                                       out,
                                        make_tuple(conv_stride_h, conv_stride_w),
                                        make_tuple(conv_dilation_h, conv_dilation_w),
                                        make_tuple(in_left_pad_h, in_left_pad_w),
                                        make_tuple(in_right_pad_h, in_right_pad_w),
                                        layout);

         check_error(in_host, in_device);
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @ ed068043
...
@@ -11,7 +11,6 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "conv_common.hpp"
-#include "host_conv.hpp"
 #include "device_tensor.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
...
@@ -26,6 +25,15 @@
 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
enum ConvTensorLayout
{
    NCHW,
    NHWC,
    CHWN,
    NCHWc,
    NHWCc
};

enum ConvForwardAlgo
{
    V4R4NCHW, // 0
...
@@ -35,6 +43,93 @@ enum ConvForwardAlgo
    V4R4R4XDLNHWC // 4
};
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_convolution_forward(const Tensor<TIn>& in,
                              const Tensor<TWei>& wei,
                              Tensor<TOut>& out,
                              const ConvStrides& conv_strides,
                              const ConvDilations& conv_dilations,
                              const InLeftPads& in_left_pads,
                              const InRightPads&,
                              const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
        double v = 0;

        for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
        {
            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];

                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];

                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
                    {
                        v += static_cast<const double>(in(n, c, hi, wi)) *
                             static_cast<const double>(wei(k, c, y, x));
                    }
                }
            }
        }

        out(n, k, ho, wo) = v;
    };

    auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
        double v = 0;

        for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
        {
            for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];

                for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];

                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[2])
                    {
                        v += static_cast<const double>(in(n, hi, wi, c)) *
                             static_cast<const double>(wei(k, y, x, c));
                    }
                }
            }
        }

        out(n, ho, wo, k) = v;
    };

    if(layout == ConvTensorLayout::NCHW)
    {
        make_ParallelTensorFunctor(f_nchw,
                                   out.mDesc.GetLengths()[0],
                                   out.mDesc.GetLengths()[1],
                                   out.mDesc.GetLengths()[2],
                                   out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        make_ParallelTensorFunctor(f_nhwc,
                                   out.mDesc.GetLengths()[0],
                                   out.mDesc.GetLengths()[1],
                                   out.mDesc.GetLengths()[2],
                                   out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}
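The loops above rely implicitly on the standard convolution output-size relation, with YEff = (Y - 1)*dilation + 1 as used elsewhere in these drivers: Ho = (Hi + LeftPy + RightPy - YEff)/Sy + 1. A compile-time restatement (illustrative values only):

    constexpr int out_size(int hi, int pad_l, int pad_r, int y, int dil, int stride)
    {
        return (hi + pad_l + pad_r - ((y - 1) * dil + 1)) / stride + 1;
    }
    static_assert(out_size(224, 1, 1, 3, 1, 2) == 112, "3x3 filter, stride 2, pad 1");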
int main(int argc, char* argv[])
{
    using namespace ck;
...
@@ -395,14 +490,14 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_direct_convolution(in,
-                                wei,
-                                out_host,
+        host_convolution_forward(in,
+                                 wei,
+                                 out_host,
                                  make_tuple(conv_stride_h, conv_stride_w),
                                  make_tuple(conv_dilation_h, conv_dilation_w),
                                  make_tuple(in_left_pad_h, in_left_pad_w),
                                  make_tuple(in_right_pad_h, in_right_pad_w),
                                  layout);

         check_error(out_host, out_device);
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @ ed068043
...
@@ -11,18 +11,124 @@
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
 #include "conv_common.hpp"
-#include "host_conv_bwd_weight.hpp"
 #include "device_tensor.hpp"
 #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
+
+enum ConvTensorLayout
+{
+    NCHW,
+    NHWC,
+    CHWN,
+    NCHWc,
+    NHWCc
+};

 #define USE_DYNAMIC_MODE 1
-#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
+#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
+#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
+#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
+#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
+#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
 enum ConvBackwardWeightAlgo
 {
-    V4R4R2XDLNCHW,
+    V4R4R2XDLNCHW,       // 0
+    V4R4R4XDLNHWC,       // 1
+    V4R4R2XDLATOMICNCHW, // 2
+    V4R4R4XDLATOMICNHWC, // 3
+    V4R4R5XDLATOMICNHWC, // 4
 };
template <typename TOut,
          typename TIn,
          typename TWei,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_convolution_backward_weight(const Tensor<TOut>& out,
                                      const Tensor<TIn>& in,
                                      Tensor<TWei>& wei,
                                      const ConvStrides& conv_strides,
                                      const ConvDilations& conv_dilations,
                                      const InLeftPads& in_left_pads,
                                      const InRightPads&,
                                      const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
        double v = 0;

        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
        {
            for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];

                for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];

                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
                    {
                        v += static_cast<const double>(in(n, c, hi, wi)) *
                             static_cast<const double>(out(n, k, ho, wo));
                    }
                }
            }
        }

        wei(k, c, y, x) = v;
    };

    auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
        double v = 0;

        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
        {
            for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];

                for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];

                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[2])
                    {
                        v += static_cast<const double>(in(n, hi, wi, c)) *
                             static_cast<const double>(out(n, ho, wo, k));
                    }
                }
            }
        }

        wei(k, y, x, c) = v;
    };

    if(layout == ConvTensorLayout::NCHW)
    {
        make_ParallelTensorFunctor(f_kcyx,
                                   wei.mDesc.GetLengths()[0],
                                   wei.mDesc.GetLengths()[1],
                                   wei.mDesc.GetLengths()[2],
                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        make_ParallelTensorFunctor(f_kyxc,
                                   wei.mDesc.GetLengths()[0],
                                   wei.mDesc.GetLengths()[1],
                                   wei.mDesc.GetLengths()[2],
                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}
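For backward-weight the reduction dimension is GemmK = N*Ho*Wo while the output is only K x (Y*X*C), which is why the atomic split-K variants exist: without splitting K there are often too few output tiles to occupy the GPU. The GemmKBatch choice used by the v4r4r5 device code, restated (a sketch with assumed numbers):

    #include <algorithm>

    long choose_kbatch(long gemm_m, long gemm_n, long m_per_blk, long n_per_blk,
                       long desired_grid_size)
    {
        const long grid_mn = (gemm_m / m_per_blk) * (gemm_n / n_per_blk);
        return std::max(desired_grid_size / grid_mn, 1L);
    }
    // choose_kbatch(256, 2304, 128, 128, 240) == 6  (256x2304 output -> only 36 tiles)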
int main(int argc, char* argv[])
{
    using namespace ck;
...
@@ -37,10 +143,11 @@ int main(int argc, char* argv[])
 #if USE_DYNAMIC_MODE
     // dynamic mode
-    if(argc != 22)
+    if(argc != 23)
     {
         printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
         printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
+        printf("additional: desired_grid_size\n");
         exit(1);
     }
...
@@ -68,6 +175,8 @@ int main(int argc, char* argv[])
     const index_t in_right_pad_h = std::stoi(argv[20]);
     const index_t in_right_pad_w = std::stoi(argv[21]);

+    const index_t desired_grid_size = std::stoi(argv[22]);
+
     const index_t YEff = (Y - 1) * conv_dilation_h + 1;
     const index_t XEff = (X - 1) * conv_dilation_w + 1;
...
@@ -114,16 +223,19 @@ int main(int argc, char* argv[])
 #if 0
     using in_data_t  = float;
+    using wei_data_t = float;
     using acc_data_t = float;
     using out_data_t = float;
 #elif 1
     using in_data_t  = half_t;
-    using acc_data_t = float;
     using out_data_t = half_t;
+    using acc_data_t = float;
+    using wei_data_t = float;
 #elif 1
     using in_data_t  = int8_t;
-    using acc_data_t = int32_t;
     using out_data_t = int8_t;
+    using acc_data_t = int32_t;
+    using wei_data_t = int8_t;
 #endif

     std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
...
@@ -164,8 +276,8 @@ int main(int argc, char* argv[])
     }

     Tensor<in_data_t> in(in_lengths_host);
-    Tensor<in_data_t> wei_device(wei_lengths_host);
-    Tensor<out_data_t> wei_host(wei_lengths_host);
+    Tensor<wei_data_t> wei_device(wei_lengths_host);
+    Tensor<wei_data_t> wei_host(wei_lengths_host);
     Tensor<out_data_t> out(out_lengths_host);

     std::cout << "layout: " << layout << std::endl;
...
@@ -231,6 +343,26 @@ int main(int argc, char* argv[])
                           in_right_pads_dev);
     };

+    auto f_make_for_device_nhwc = [&]() {
+        const auto in_lengths_dev     = make_tuple(N, Hi, Wi, C);
+        const auto wei_lengths_dev    = make_tuple(K, Y, X, C);
+        const auto out_lengths_dev    = make_tuple(N, Ho, Wo, K);
+        const auto conv_strides_dev   = make_tuple(conv_stride_h, conv_stride_w);
+        const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
+        const auto in_left_pads_dev   = make_tuple(in_left_pad_h, in_left_pad_w);
+        const auto in_right_pads_dev  = make_tuple(in_right_pad_h, in_right_pad_w);
+
+        return make_tuple(in_lengths_dev,
+                          wei_lengths_dev,
+                          out_lengths_dev,
+                          conv_strides_dev,
+                          conv_dilations_dev,
+                          in_left_pads_dev,
+                          in_right_pads_dev);
+    };
+
+    // set zero to wei_device
+    wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
+
 #if USE_CONV_WRW_V4R4R2_XDL_NCHW
     if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
     {
...
@@ -242,6 +374,7 @@ int main(int argc, char* argv[])
        const auto tmp = f_make_for_device_nchw();

        device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
+                                                                                     wei_data_t,
                                                                                      acc_data_t,
                                                                                      out_data_t>(
            tmp[I0],
...
@@ -258,16 +391,131 @@ int main(int argc, char* argv[])
    }
#endif

+#if USE_CONV_WRW_V4R4R4_XDL_NHWC
+   if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC)
+   {
+       if(layout != ConvTensorLayout::NHWC)
+       {
+           throw std::runtime_error("wrong! layout");
+       }
+
+       const auto tmp = f_make_for_device_nhwc();
+
+       device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
+                                                                                     wei_data_t,
+                                                                                     acc_data_t,
+                                                                                     out_data_t>(
+           tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6],
+           in, wei_device, out, nrepeat);
+   }
+#endif
+
+#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW
+   if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW)
+   {
+       if(layout != ConvTensorLayout::NCHW)
+       {
+           throw std::runtime_error("wrong! layout");
+       }
+
+       const auto tmp = f_make_for_device_nchw();
+
+       device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
+           in_data_t,
+           wei_data_t,
+           acc_data_t,
+           out_data_t>(
+           tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6],
+           in, wei_device, out, desired_grid_size, nrepeat);
+   }
+#endif
+
+#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
+   if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC)
+   {
+       if(layout != ConvTensorLayout::NHWC)
+       {
+           throw std::runtime_error("wrong! layout");
+       }
+
+       const auto tmp = f_make_for_device_nhwc();
+
+       device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk<
+           in_data_t,
+           wei_data_t,
+           acc_data_t,
+           out_data_t>(
+           tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6],
+           in, wei_device, out, desired_grid_size, nrepeat);
+   }
+#endif
+
+#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC
+   if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC)
+   {
+       if(layout != ConvTensorLayout::NHWC)
+       {
+           throw std::runtime_error("wrong! layout");
+       }
+
+       const auto tmp = f_make_for_device_nhwc();
+
+       device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk<
+           in_data_t,
+           wei_data_t,
+           acc_data_t,
+           out_data_t>(
+           tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6],
+           in, wei_device, out, desired_grid_size, nrepeat);
+   }
+#endif
    if(do_verification)
    {
-       host_direct_convolution_backward_weights(out,
-                                                in,
-                                                wei_host,
-                                                make_tuple(conv_stride_h, conv_stride_w),
-                                                make_tuple(conv_dilation_h, conv_dilation_w),
-                                                make_tuple(in_left_pad_h, in_left_pad_w),
-                                                make_tuple(in_right_pad_h, in_right_pad_w),
-                                                layout);
+       host_convolution_backward_weight(out,
+                                        in,
+                                        wei_host,
+                                        make_tuple(conv_stride_h, conv_stride_w),
+                                        make_tuple(conv_dilation_h, conv_dilation_w),
+                                        make_tuple(in_left_pad_h, in_left_pad_w),
+                                        make_tuple(in_right_pad_h, in_right_pad_w),
+                                        layout);

        check_error(wei_host, wei_device);
...

host/host_tensor/include/conv_common.hpp
View file @ ed068043
...
@@ -3,15 +3,6 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
enum
ConvTensorLayout
{
NCHW
,
NHWC
,
CHWN
,
NCHWc
,
NHWCc
};
template
<
typename
...
InDesc
,
template
<
typename
...
InDesc
,
typename
...
WeiDesc
,
typename
...
WeiDesc
,
typename
ConvStrides
,
typename
ConvStrides
,
...
...
host/host_tensor/include/device.hpp
View file @ ed068043
...
@@ -2,6 +2,7 @@
#define DEVICE_HPP

#include <memory>
+#include <functional>
#include <thread>
#include <chrono>
#include "hip/hip_runtime.h"
...
@@ -80,5 +81,4 @@ float launch_and_time_kernel(
    return timer.GetElapsedTime() / nrepeat;
}
#endif
host/host_tensor/include/host_conv.hpp
View file @ ed068043
#pragma once
#include "host_tensor.hpp"
-#include "conv_common.hpp"
-
-template <typename T>
-inline auto activ(T v, const ck::index_t activ_type)
-{
-    const T alpha = 0.3;
-
-    switch(activ_type)
-    {
-    case 0: return v;
-    case 1: return (v >= 0 ? v : alpha * v);
-    case 2: return (1 / (1 + exp(-v)));
-    default: throw std::runtime_error("unsupported activ type"); break;
-    }
-}

template <typename TIn,
          typename TWei,
...
@@ -21,20 +9,16 @@ template <typename TIn,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
-void host_direct_convolution(const Tensor<TIn>& in,
+void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
                              const Tensor<TWei>& wei,
                              Tensor<TOut>& out,
                              const ConvStrides& conv_strides,
                              const ConvDilations& conv_dilations,
                              const InLeftPads& in_left_pads,
-                             const InRightPads&,
-                             const ConvTensorLayout layout = ConvTensorLayout::NCHW,
-                             const ck::index_t activ_type  = 0)
+                             const InRightPads&)
{
-    using namespace ck;
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
+    constexpr auto I0 = ck::Number<0>{};
+    constexpr auto I1 = ck::Number<1>{};

    auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
        double v = 0;
...
@@ -55,495 +39,12 @@ void host_direct_convolution(const Tensor<TIn>& in,
                }
            }
        }

-        out(n, k, ho, wo) = activ(v, activ_type);
+        out(n, k, ho, wo) = v;
    };
-    auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
-        double v = 0;
-
-        for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
-        {
-            for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
-            {
-                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
-                for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
-                {
-                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[2])
-                    {
-                        v += static_cast<const double>(in(n, hi, wi, c)) *
-                             static_cast<const double>(wei(k, y, x, c));
-                    }
-                }
-            }
-        }
-
-        out(n, ho, wo, k) = activ(v, activ_type);
-    };
-
-    if(layout == ConvTensorLayout::NCHW)
-    {
-        make_ParallelTensorFunctor(f_nchw,
-                                   out.mDesc.GetLengths()[0],
-                                   out.mDesc.GetLengths()[1],
-                                   out.mDesc.GetLengths()[2],
-                                   out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
-    }
-    else if(layout == ConvTensorLayout::NHWC)
-    {
-        make_ParallelTensorFunctor(f_nhwc,
-                                   out.mDesc.GetLengths()[0],
-                                   out.mDesc.GetLengths()[1],
-                                   out.mDesc.GetLengths()[2],
-                                   out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
-    }
-    else
-    {
-        throw std::runtime_error("wrong! not supported layout");
-    }
-}
+
+    make_ParallelTensorFunctor(f_nchw,
+                               out.mDesc.GetLengths()[0],
+                               out.mDesc.GetLengths()[1],
+                               out.mDesc.GetLengths()[2],
+                               out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
+}
<
typename
TIn
,
typename
TWei
,
typename
TOut
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
host_direct_convolution_nchwc
(
const
Tensor
<
TIn
>&
in
,
const
Tensor
<
TWei
>&
wei
,
const
Tensor
<
TOut
>&
bias
,
Tensor
<
TOut
>&
out
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
,
const
ck
::
index_t
activ_type
=
0
)
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_nchw
=
[
&
](
auto
n
,
auto
k0
,
auto
ho
,
auto
wo
,
auto
k1
)
{
double
v
=
0
;
const
int
k
=
k0
*
out
.
mDesc
.
GetLengths
()[
4
]
+
k1
;
for
(
int
c0
=
0
;
c0
<
wei
.
mDesc
.
GetLengths
()[
1
];
++
c0
)
{
for
(
int
y
=
0
;
y
<
wei
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
int
hi
=
ho
*
conv_strides
[
I0
]
+
y
*
conv_dilations
[
I0
]
-
in_left_pads
[
I0
];
for
(
int
x
=
0
;
x
<
wei
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
int
wi
=
wo
*
conv_strides
[
I1
]
+
x
*
conv_dilations
[
I1
]
-
in_left_pads
[
I1
];
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
{
for
(
int
c1
=
0
;
c1
<
wei
.
mDesc
.
GetLengths
()[
4
];
++
c1
)
{
v
+=
static_cast
<
const
double
>
(
in
(
n
,
c0
,
hi
,
wi
,
c1
))
*
static_cast
<
const
double
>
(
wei
(
k
,
c0
,
y
,
x
,
c1
));
}
}
}
}
}
v
+=
bias
(
k0
,
k1
);
out
(
n
,
k0
,
ho
,
wo
,
k1
)
=
activ
(
v
,
activ_type
);
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
f_nchw
,
out
.
mDesc
.
GetLengths
()[
0
],
out
.
mDesc
.
GetLengths
()[
0
],
out
.
mDesc
.
GetLengths
()[
1
],
out
.
mDesc
.
GetLengths
()[
1
],
out
.
mDesc
.
GetLengths
()[
2
],
out
.
mDesc
.
GetLengths
()[
2
],
out
.
mDesc
.
GetLengths
()[
3
],
out
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
out
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
}
-template <typename TIn,
-          typename TWei,
-          typename TOut,
-          typename ConvStrides,
-          typename ConvDilations,
-          typename InLeftPads,
-          typename InRightPads>
-void host_direct_convolution_add_nchwc(const Tensor<TIn>& in,
-                                       const Tensor<TWei>& wei,
-                                       const Tensor<TOut>& add,
-                                       const Tensor<TOut>& bias,
-                                       Tensor<TOut>& add_host,
-                                       Tensor<TOut>& out_host,
-                                       const ConvStrides& conv_strides,
-                                       const ConvDilations& conv_dilations,
-                                       const InLeftPads& in_left_pads,
-                                       const InRightPads&,
-                                       const ck::index_t activ_type = 0)
-{
-    using namespace ck;
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
-    auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
-        double v = 0;
-
-        auto k = k0 * out_host.mDesc.GetLengths()[4] + k1;
-
-        for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0)
-        {
-            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
-            {
-                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
-                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
-                {
-                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[3])
-                    {
-                        for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1)
-                        {
-                            v += static_cast<const double>(in(n, c0, hi, wi, c1)) *
-                                 static_cast<const double>(wei(k, c0, y, x, c1));
-                        }
-                    }
-                }
-            }
-        }
-
-        v += bias(k0, k1);
-        v = activ(v, activ_type);
-
-        const int hox2 = ho * 2;
-        const int wox2 = wo * 2;
-
-        out_host(n, k0, ho, wo, k1) = v;
-
-        add_host(n, k0, hox2, wox2, k1)         = v + add(n, k0, hox2, wox2, k1);
-        add_host(n, k0, hox2, wox2 + 1, k1)     = v + add(n, k0, hox2, wox2 + 1, k1);
-        add_host(n, k0, hox2 + 1, wox2, k1)     = v + add(n, k0, hox2 + 1, wox2, k1);
-        add_host(n, k0, hox2 + 1, wox2 + 1, k1) = v + add(n, k0, hox2 + 1, wox2 + 1, k1);
-    };
-
-    make_ParallelTensorFunctor(f_nchw,
-                               out_host.mDesc.GetLengths()[0],
-                               out_host.mDesc.GetLengths()[1],
-                               out_host.mDesc.GetLengths()[2],
-                               out_host.mDesc.GetLengths()[3],
-                               out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
-}
-template <typename TIn,
-          typename TWei,
-          typename TOut,
-          typename ConvStrides,
-          typename ConvDilations,
-          typename InLeftPads,
-          typename InRightPads>
-void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in,
-                                           const Tensor<TWei>& wei,
-                                           const Tensor<TOut>& bias,
-                                           Tensor<TOut>& out_host,
-                                           Tensor<TOut>& max_host,
-                                           const ConvStrides& conv_strides,
-                                           const ConvDilations& conv_dilations,
-                                           const InLeftPads& in_left_pads,
-                                           const InRightPads&,
-                                           const ck::index_t activ_type = 0)
-{
-    using namespace ck;
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
-    auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
-        double v = 0;
-
-        auto k = k0 * out_host.mDesc.GetLengths()[4] + k1;
-
-        for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0)
-        {
-            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
-            {
-                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
-                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
-                {
-                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[3])
-                    {
-                        for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1)
-                        {
-                            v += static_cast<const double>(in(n, c0, hi, wi, c1)) *
-                                 static_cast<const double>(wei(k, c0, y, x, c1));
-                        }
-                    }
-                }
-            }
-        }
-
-        v += bias(k0, k1);
-        v = activ(v, activ_type);
-
-        out_host(n, k0, ho, wo, k1) = v;
-    };
-
-    make_ParallelTensorFunctor(f_nchw,
-                               out_host.mDesc.GetLengths()[0],
-                               out_host.mDesc.GetLengths()[1],
-                               out_host.mDesc.GetLengths()[2],
-                               out_host.mDesc.GetLengths()[3],
-                               out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
-
-    auto maxpool_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
-        auto hx = ho * 2;
-        auto wx = wo * 2;
-
-        auto v0 = out_host(n, k0, hx, wx, k1);
-        auto v1 = out_host(n, k0, hx, wx + 1, k1);
-        auto v2 = out_host(n, k0, hx + 1, wx, k1);
-        auto v3 = out_host(n, k0, hx + 1, wx + 1, k1);
-
-        max_host(n, k0, ho, wo, k1) = std::max({v0, v1, v2, v3});
-    };
-
-    make_ParallelTensorFunctor(maxpool_nchw,
-                               max_host.mDesc.GetLengths()[0],
-                               max_host.mDesc.GetLengths()[1],
-                               max_host.mDesc.GetLengths()[2],
-                               max_host.mDesc.GetLengths()[3],
-                               max_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
-}
-template <typename TIn,
-          typename TWei,
-          typename TOut,
-          typename InLeftPads,
-          typename InRightPads>
-void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
-                                   const Tensor<TWei>& wei_kcyx,
-                                   Tensor<TOut>& out_nkhw,
-                                   InLeftPads,
-                                   InRightPads)
-{
-    using namespace ck;
-
-    constexpr std::size_t HoPerTile = 2;
-    constexpr std::size_t WoPerTile = 2;
-
-    std::size_t N = in_nchw.mDesc.GetLengths()[0];
-    std::size_t C = in_nchw.mDesc.GetLengths()[1];
-
-    std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
-    std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
-    std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
-
-    std::size_t Ho = out_nkhw.mDesc.GetLengths()[2];
-    std::size_t Wo = out_nkhw.mDesc.GetLengths()[3];
-
-    index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
-    index_t w_pad_low = InLeftPads{}.Get(Number<1>{});
-
-    std::size_t HiPerTile = HoPerTile + Y - 1;
-    std::size_t WiPerTile = WoPerTile + X - 1;
-
-    std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile;
-    std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile;
-
-    Tensor<double> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
-    Tensor<double> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
-    Tensor<double> wei_transform({K, C, HiPerTile, WiPerTile});
-    Tensor<double> out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile});
-    Tensor<double> out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile});
-
-    auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) {
-        for(int j = 0; j < HiPerTile; ++j)
-        {
-            int hi = HoPerTile * htile + j - h_pad_low;
-            for(int i = 0; i < WiPerTile; ++i)
-            {
-                int wi = WoPerTile * wtile + i - w_pad_low;
-                if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
-                   wi < in_nchw.mDesc.GetLengths()[3])
-                {
-                    in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi);
-                }
-                else
-                {
-                    in_hold(n, c, htile, wtile, j, i) = TIn(0);
-                }
-            }
-        }
-    };
-
-    auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) {
-        in_transform(n, c, htile, wtile, 0, 0) = in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 0, 1) = in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 0, 2) = -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 0, 3) = in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3);
-        in_transform(n, c, htile, wtile, 1, 0) = in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 1, 1) = in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 1, 2) = -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 1, 3) = in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
-        in_transform(n, c, htile, wtile, 2, 0) = -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 2, 1) = -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 2, 2) = in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
-        in_transform(n, c, htile, wtile, 2, 3) = -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
-        in_transform(n, c, htile, wtile, 3, 0) = in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2);
-        in_transform(n, c, htile, wtile, 3, 1) = in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
-        in_transform(n, c, htile, wtile, 3, 2) = -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
-        in_transform(n, c, htile, wtile, 3, 3) = in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3);
-    };
-
-    auto f_wei_transform = [&](auto k, auto c) {
-        wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0));
-        wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + 0.5 * double(wei_kcyx(k, c, 0, 1)) + 0.5 * double(wei_kcyx(k, c, 0, 2));
-        wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - 0.5 * double(wei_kcyx(k, c, 0, 1)) + 0.5 * double(wei_kcyx(k, c, 0, 2));
-        wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2));
-        wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + 0.5 * double(wei_kcyx(k, c, 1, 0)) + 0.5 * double(wei_kcyx(k, c, 2, 0));
-        wei_transform(k, c, 1, 1) = 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) + 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + 0.25 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 1, 2) = 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) - 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + 0.25 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) + 0.5 * double(wei_kcyx(k, c, 1, 2)) + 0.5 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - 0.5 * double(wei_kcyx(k, c, 1, 0)) + 0.5 * double(wei_kcyx(k, c, 2, 0));
-        wei_transform(k, c, 2, 1) = 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) - 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + 0.25 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 2, 2) = 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) + 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + 0.25 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) - 0.5 * double(wei_kcyx(k, c, 1, 2)) + 0.5 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0));
-        wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) + 0.5 * double(wei_kcyx(k, c, 2, 1)) + 0.5 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) - 0.5 * double(wei_kcyx(k, c, 2, 1)) + 0.5 * double(wei_kcyx(k, c, 2, 2));
-        wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2));
-    };
-
-    auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
-        for(int j = 0; j < HiPerTile; ++j)
-        {
-            for(int i = 0; i < WiPerTile; ++i)
-            {
-                double v = 0;
-                for(int c = 0; c < C; ++c)
-                {
-                    v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i);
-                }
-                out_transform(n, k, htile, wtile, j, i) = v;
-            }
-        }
-    };
-
-    auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) {
-        out_hold(n, k, htile, wtile, 0, 0) = out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + out_transform(n, k, htile, wtile, 2, 2);
-        out_hold(n, k, htile, wtile, 0, 1) = out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - out_transform(n, k, htile, wtile, 2, 3);
-        out_hold(n, k, htile, wtile, 1, 0) = out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - out_transform(n, k, htile, wtile, 3, 2);
-        out_hold(n, k, htile, wtile, 1, 1) = out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + out_transform(n, k, htile, wtile, 3, 3);
-    };
-
-    auto f_out = [&](auto n, auto k, auto htile, auto wtile) {
-        for(int j = 0; j < HoPerTile; ++j)
-        {
-            std::size_t ho = HoPerTile * htile + j;
-            for(int i = 0; i < WoPerTile; ++i)
-            {
-                std::size_t wo = WoPerTile * wtile + i;
-                out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
-            }
-        }
-    };
-
-    std::size_t num_thread = std::thread::hardware_concurrency();
-
-    make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread);
-    make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread);
-    make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread);
-    make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread);
-    make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
-    make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
-}
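For readers tracing the coefficients: the removed function is the standard Winograd F(2×2, 3×3) algorithm. In matrix form (a restatement of the elementwise code above, with d a 4×4 input tile, g the 3×3 filter, and ⊙ the elementwise product):

\[ Y = A^{\mathsf T}\left[(G\,g\,G^{\mathsf T}) \odot (B^{\mathsf T} d\,B)\right] A, \quad B^{\mathsf T} = \begin{pmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & 1 & 0 & -1 \end{pmatrix}, \quad G = \begin{pmatrix} 1 & 0 & 0 \\ \tfrac12 & \tfrac12 & \tfrac12 \\ \tfrac12 & -\tfrac12 & \tfrac12 \\ 0 & 0 & 1 \end{pmatrix}, \quad A^{\mathsf T} = \begin{pmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{pmatrix} \]

f_in_transform spells out B^T d B, f_wei_transform spells out G g G^T, and f_out_hold applies A^T (·) A; summing over channels in the transform domain (f_out_transform) is what cuts the multiply count per 2×2 output tile from 36 to 16.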
host/host_tensor/include/host_conv_bwd_data.hpp
deleted 100644 → 0
View file @ 41852668
#pragma once
#include "host_tensor.hpp"

template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_direct_convolution_backward_data(Tensor<TIn>& in,
                                           const Tensor<TWei>& wei,
                                           const Tensor<TOut>& out,
                                           const ConvStrides& conv_strides,
                                           const ConvDilations& conv_dilations,
                                           const InLeftPads& in_left_pads,
                                           const InRightPads& /* in_right_pads */,
                                           const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
        std::size_t K  = wei.mDesc.GetLengths()[I0];
        std::size_t Y  = wei.mDesc.GetLengths()[I2];
        std::size_t X  = wei.mDesc.GetLengths()[I3];

        std::size_t Ho = out.mDesc.GetLengths()[I2];
        std::size_t Wo = out.mDesc.GetLengths()[I3];

        double v = 0;
        for(int y = 0; y < Y; ++y)
        {
            int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
            if(h_tmp % conv_strides[I0] == 0)
            {
                int ho = h_tmp / conv_strides[I0];
                if(ho >= 0 && ho < Ho)
                {
                    for(int x = 0; x < X; ++x)
                    {
                        int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
                        if(w_tmp % conv_strides[I1] == 0)
                        {
                            int wo = w_tmp / conv_strides[I1];
                            if(wo >= 0 && wo < Wo)
                            {
                                for(int k = 0; k < K; ++k)
                                {
                                    v += out(n, k, ho, wo) * wei(k, c, y, x);
                                }
                            }
                        }
                    }
                }
            }
        }
        in(n, c, hi, wi) = v;
    };

    auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
        std::size_t K  = wei.mDesc.GetLengths()[I0];
        std::size_t Y  = wei.mDesc.GetLengths()[I1];
        std::size_t X  = wei.mDesc.GetLengths()[I2];

        std::size_t Ho = out.mDesc.GetLengths()[I1];
        std::size_t Wo = out.mDesc.GetLengths()[I2];

        double v = 0;
        for(int y = 0; y < Y; ++y)
        {
            int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
            if(h_tmp % conv_strides[I0] == 0)
            {
                int ho = h_tmp / conv_strides[I0];
                if(ho >= 0 && ho < Ho)
                {
                    for(int x = 0; x < X; ++x)
                    {
                        int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
                        if(w_tmp % conv_strides[I1] == 0)
                        {
                            int wo = w_tmp / conv_strides[I1];
                            if(wo >= 0 && wo < Wo)
                            {
                                for(int k = 0; k < K; ++k)
                                {
                                    v += out(n, ho, wo, k) * wei(k, y, x, c);
                                }
                            }
                        }
                    }
                }
            }
        }
        in(n, hi, wi, c) = v;
    };

    if(layout == ConvTensorLayout::NCHW)
    {
        make_ParallelTensorFunctor(f_nchw,
                                   in.mDesc.GetLengths()[0],
                                   in.mDesc.GetLengths()[1],
                                   in.mDesc.GetLengths()[2],
                                   in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        make_ParallelTensorFunctor(f_nhwc,
                                   in.mDesc.GetLengths()[0],
                                   in.mDesc.GetLengths()[1],
                                   in.mDesc.GetLengths()[2],
                                   in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}
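As a restatement of the reduction this deleted reference performed (NCHW indexing; the NHWC lambda permutes dimensions the same way):

\[ \mathrm{in}(n,c,h_i,w_i) = \sum_{k}\sum_{y}\sum_{x} \mathrm{out}\!\left(n,k,\tfrac{h_i + p_h - y\,d_h}{s_h},\tfrac{w_i + p_w - x\,d_w}{s_w}\right)\mathrm{wei}(k,c,y,x) \]

where a (y, x) term contributes only when both quotients are exact integers inside [0, Ho) × [0, Wo) — exactly the % and bounds checks in f_nchw / f_nhwc above.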
host/host_tensor/include/host_conv_bwd_weight.hpp
deleted 100644 → 0
View file @ 41852668
#pragma once
#include "host_tensor.hpp"

template <typename TOut,
          typename TIn,
          typename TWei,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void host_direct_convolution_backward_weights(const Tensor<TOut>& out,
                                              const Tensor<TIn>& in,
                                              Tensor<TWei>& wei,
                                              const ConvStrides& conv_strides,
                                              const ConvDilations& conv_dilations,
                                              const InLeftPads& in_left_pads,
                                              const InRightPads&,
                                              const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
        double v = 0;
        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
        {
            for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
                for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[3])
                    {
                        v += static_cast<const double>(in(n, c, hi, wi)) *
                             static_cast<const double>(out(n, k, ho, wo));
                    }
                }
            }
        }
        wei(k, c, y, x) = v;
    };

    auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
        double v = 0;
        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
        {
            for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
            {
                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
                for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
                {
                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
                       wi < in.mDesc.GetLengths()[2])
                    {
                        v += static_cast<const double>(in(n, hi, wi, c)) *
                             static_cast<const double>(out(n, ho, wo, k));
                    }
                }
            }
        }
        wei(k, y, x, c) = v;
    };

    if(layout == ConvTensorLayout::NCHW)
    {
        make_ParallelTensorFunctor(f_kcyx,
                                   wei.mDesc.GetLengths()[0],
                                   wei.mDesc.GetLengths()[1],
                                   wei.mDesc.GetLengths()[2],
                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        make_ParallelTensorFunctor(f_kyxc,
                                   wei.mDesc.GetLengths()[0],
                                   wei.mDesc.GetLengths()[1],
                                   wei.mDesc.GetLengths()[2],
                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}
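As a restatement of what this deleted reference computed (NCHW indexing):

\[ \mathrm{wei}(k,c,y,x) = \sum_{n}\sum_{h_o}\sum_{w_o} \mathrm{in}(n,c,\,h_o s_h + y\,d_h - p_h,\,w_o s_w + x\,d_w - p_w)\,\mathrm{out}(n,k,h_o,w_o) \]

with out-of-range input coordinates (the padding region) contributing zero. Viewed as a GEMM, K × (C·Y·X) is the output matrix and N·Ho·Wo is the reduced dimension — the dimension the atomic wrw kernels in this commit split across workgroups.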
host/host_tensor/include/host_gemm.hpp
View file @ ed068043
...
@@ -157,3 +157,26 @@ void host_gemm(const Tensor<AType>& a,
        throw std::runtime_error("wrong! not supported layout");
    }
}
+
+template <typename AType, typename BType, typename CType>
+void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        Tensor<CType>& c_m_n)
+{
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];
+
+        double v = 0;
+
+        for(int k = 0; k < K; ++k)
+        {
+            v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
+        }
+
+        c_m_n(m, n) = v;
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}
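A minimal usage sketch of the new reference GEMM (shapes and fill values are illustrative; Tensor and host_gemm_mk_kn_mn are as defined above):

std::vector<std::size_t> a_len{8, 16}, b_len{16, 4}, c_len{8, 4};
Tensor<float> a_m_k(a_len), b_k_n(b_len), c_m_n(c_len);

// Fill A with 1s and B with 2s through the element accessor used in this file.
for(std::size_t m = 0; m < 8; ++m)
    for(std::size_t k = 0; k < 16; ++k)
        a_m_k(m, k) = 1.f;

for(std::size_t k = 0; k < 16; ++k)
    for(std::size_t n = 0; n < 4; ++n)
        b_k_n(k, n) = 2.f;

host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n); // every c_m_n(m, n) == 16 * 1 * 2 == 32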
host/host_tensor/include/host_tensor.hpp
View file @ ed068043
...
@@ -120,6 +120,8 @@ struct HostTensorDescriptor
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

+   friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
+
    private:
    std::vector<std::size_t> mLens;
    std::vector<std::size_t> mStrides;
...
@@ -224,7 +226,7 @@ struct Tensor
    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}

    template <typename G>
-   void GenerateTensorValue(G g, std::size_t num_thread = 1)
+   void GenerateTensorValue(G g, std::size_t num_thread = std::thread::hardware_concurrency())
    {
        switch(mDesc.GetNumOfDimension())
        {
...