Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5ce317cb
Commit
5ce317cb
authored
Oct 07, 2021
by
Jing Zhang
Browse files
add fwd_driver_offline_nchwc
parents
71bc108d
b2dc55f8
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3537 additions
and
550 deletions
+3537
-550
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
...kward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
+193
-78
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp
...d_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp
+389
-0
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...ard_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+17
-12
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+68
-8
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+32
-58
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
+289
-45
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
+263
-0
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
+289
-45
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
+263
-0
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
+290
-46
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
+291
-0
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
+337
-48
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
+347
-0
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+3
-3
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+14
-6
host/driver_offline/src/conv_bwd_driver_offline.cpp
host/driver_offline/src/conv_bwd_driver_offline.cpp
+42
-17
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+12
-87
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
+306
-0
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+2
-1
host/driver_offline/src/gemm_driver_offline.cpp
host/driver_offline/src/gemm_driver_offline.cpp
+90
-96
No files found.
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
View file @
5ce317cb
...
@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
#if 0
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4]
, C = 128,
for fp32
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmMPerBlock = 256;
...
@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif
0
#elif
0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
// [M, N, K0, K1] = [128, 128, 4, 4]
, C = 64,
for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
...
@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif
1
#elif
0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
...
@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
// [M, N, K0, K1] = [128, 256, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
...
@@ -160,23 +160,91 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -160,23 +160,91 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
const
auto
descs
=
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
2
;
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
out_n_ho_wo_k_desc
,
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
wei_k_y_x_c_desc
,
in_n_hi_wi_c_desc
,
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
conv_strides
,
#elif 0
conv_dilations
,
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
in_left_pads
,
constexpr
index_t
BlockSize
=
128
;
in_right_pads
,
I0
,
constexpr
index_t
GemmMPerBlock
=
128
;
I0
,
constexpr
index_t
GemmNPerBlock
=
64
;
Number
<
GemmK1
>
{});
constexpr
index_t
GemmKPerBlock
=
4
;
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
constexpr
index_t
GemmMPerWave
=
32
;
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
constexpr
index_t
GemmNPerWave
=
32
;
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
...
@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-: gemmk0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-:
// gemmk1
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
...
@@ -215,7 +284,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -215,7 +284,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
//clang-format on
//
clang-format on
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
...
@@ -225,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -225,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
const
auto
ConvStrideH
=
conv_strides
[
I0
];
BlockSize
,
const
auto
ConvStrideW
=
conv_strides
[
I1
];
TInWei
,
TAcc
,
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
TOut
,
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_desc
),
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_desc
),
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
decltype
(
in_gemmm_gemmn_grid_desc
),
GemmMPerBlock
,
const
auto
YTilda
=
ConvStrideH
/
GcdStrideDilationH
;
GemmNPerBlock
,
const
auto
XTilda
=
ConvStrideW
/
GcdStrideDilationW
;
GemmKPerBlock
,
GemmMPerWave
,
float
ave_time
=
0
;
GemmNPerWave
,
GemmK1
,
for
(
index_t
i_ytilda
=
0
;
i_ytilda
<
YTilda
;
++
i_ytilda
)
MRepeat
,
{
NRepeat
,
for
(
index_t
i_xtilda
=
0
;
i_xtilda
<
XTilda
;
++
i_xtilda
)
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
{
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
const
auto
descs
=
Sequence
<
1
,
0
,
2
>
,
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
Sequence
<
1
,
0
,
2
>
,
out_n_ho_wo_k_desc
,
2
,
wei_k_y_x_c_desc
,
GemmABlockTransferSrcScalarPerVector_GemmK1
,
in_n_hi_wi_c_desc
,
GemmABlockTransferDstScalarPerVector_GemmK1
,
conv_strides
,
false
,
// don't move back src coordinate after threadwise copy
conv_dilations
,
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
in_left_pads
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
in_right_pads
,
Sequence
<
2
,
0
,
1
>
,
i_ytilda
,
Sequence
<
0
,
2
,
1
>
,
i_xtilda
,
1
,
Number
<
GemmK1
>
{});
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmK1
,
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
false
,
// don't move back src coordinate after threadwise copy
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
GemmK0
=
out_gemmk0_gemmm_gemmk1_grid_desc
.
GetLength
(
I0
);
if
(
GemmK0
!=
0
)
{
ave_time
+=
driver_gemm_xdlops_v2r3
<
BlockSize
,
TInWei
,
TAcc
,
TOut
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_desc
),
decltype
(
in_gemmm_gemmn_grid_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
GemmK1
,
MRepeat
,
NRepeat
,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
GemmABlockTransferSrcScalarPerVector_GemmK1
,
GemmABlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
Sequence
<
2
,
0
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
#if 0
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
#else
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
#endif
#endif
7
,
7
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
true
// CAccessOrderMRepeatNRepeat
true
,
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
false
,
// ABlockLdsExtraM
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
false
// BBlockLdsExtraN
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
out_gemmk0_gemmm_gemmk1_grid_desc
,
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
wei_gemmk0_gemmn_gemmk1_grid_desc
,
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
in_gemmm_gemmn_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
,
in_gemmm_gemmn_grid_desc
,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
nrepeat
);
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
}
}
}
{
{
const
auto
N
=
out_n_ho_wo_k_lengths
[
I0
];
const
auto
N
=
out_n_ho_wo_k_lengths
[
I0
];
...
...
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp
0 → 100644
View file @
5ce317cb
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
typename
TAcc
,
typename
TOut
,
typename
InLengths
,
typename
WeiLengths
,
typename
OutLengths
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
,
const
InLeftPads
&
,
const
InRightPads
&
,
Tensor
<
TInWei
>&
in_n_hi_wi_c
,
const
Tensor
<
TInWei
>&
wei_k_y_x_c
,
const
Tensor
<
TOut
>&
out_n_ho_wo_k
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_ho_wo_k_device_buf
(
sizeof
(
TOut
)
*
out_n_ho_wo_k
.
mDesc
.
GetElementSpace
());
in_n_hi_wi_c_device_buf
.
ToDevice
(
in_n_hi_wi_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
4
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
>
{}),
// 2+: gemmk1
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: gemmk0
Sequence
<
0
,
0
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
>
{}));
// 2-: gemmk1
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
>
{}),
// 2+: gemmk1
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: Gemmk0
Sequence
<
0
,
0
,
0
>
{},
// 1-: Gemmn
Sequence
<
0
,
0
,
0
>
{}));
// 2-: Gemmk1
// clang-format off
constexpr
auto
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
// clang-format on
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
const
auto
descs
=
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1
(
out_n_ho_wo_k_desc
,
wei_k_y_x_c_desc
,
in_n_hi_wi_c_desc
,
conv_strides
,
Number
<
GemmK1
>
{});
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
TInWei
,
TAcc
,
TOut
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_desc
),
decltype
(
in_gemmm_gemmn_grid_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
GemmK1
,
MRepeat
,
NRepeat
,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
GemmABlockTransferSrcScalarPerVector_GemmK1
,
GemmABlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
Sequence
<
2
,
0
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
#endif
7
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
),
decltype
(
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
true
,
// CAccessOrderMRepeatNRepeat
false
,
// ABlockLdsExtraM
false
// BBlockLdsExtraN
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
{
const
auto
N
=
out_n_ho_wo_k_lengths
[
I0
];
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf
.
FromDevice
(
in_n_hi_wi_c
.
mData
.
data
());
}
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
View file @
5ce317cb
...
@@ -203,18 +203,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -203,18 +203,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
>
(
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
false
,
// CAccessOrderMRepeatNRepeat
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
true
,
// ABlockLdsExtraM
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
true
// BBlockLdsExtraN
out_gemmk0_gemmm_gemmk1_grid_desc
,
>
(
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmn_gemmk1_grid_desc
,
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
wei_gemmm_gemmn_grid_desc
,
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
wei_gemmm_gemmn_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
nrepeat
);
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
))
/
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
))
/
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @
5ce317cb
...
@@ -49,7 +49,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -49,7 +49,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
#if 0
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4]
, C = 128,
for fp32
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmMPerBlock = 256;
...
@@ -77,7 +77,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -77,7 +77,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif
0
#elif
0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
// [M, N, K0, K1] = [128, 128, 4, 4]
, C = 128,
for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
...
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
// [M, N, K0, K1] = [256, 256, 4, 8]
, C = 256,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
...
@@ -133,7 +133,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -133,7 +133,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
...
@@ -160,8 +160,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -160,8 +160,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif
1
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
// [M, N, K0, K1] = [128, 256, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
...
@@ -189,7 +189,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -189,7 +189,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8]
, C = 64,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
...
@@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
#endif
...
@@ -316,13 +372,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -316,13 +372,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
decltype
(
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks
,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nc
0
hw
c1
_kc
0
yx
c1
_nk
0
hw
k1
.hpp
View file @
5ce317cb
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nc
0
hw
c1
_kc
0
yx
c1
_nk
0
hw
k1
.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
typename
TOut
,
typename
TOut
,
ck
::
index_t
InWeiVectorSize
,
ck
::
index_t
activ_type
,
ck
::
index_t
activ_type
,
typename
InLengths
,
typename
InLengths
,
typename
WeiLengths
,
typename
WeiLengths
,
...
@@ -15,17 +14,17 @@ template <typename TInWei,
...
@@ -15,17 +14,17 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw
(
void
device_convolution_forward_implicit_gemm_v5r1_dlops_nc
0
hw
c1
_kc
0
yx
c1
_nk
0
hw
k1
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
InLengths
&
in_n_c
0
_hi_wi_
c1_
lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
WeiLengths
&
wei_k_c
0
_y_x_
c1_
lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
OutLengths
&
out_n_k
0
_ho_wo_
k1_
lengths
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
Tensor
<
TInWei
>&
in_n_c_hi_wi
,
const
Tensor
<
TInWei
>&
in_n_c
0
_hi_wi
_c1
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
const
Tensor
<
TInWei
>&
wei_k_c
0
_y_x
_c1
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
Tensor
<
TOut
>&
out_n_k
0
_ho_wo
_k1
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -36,43 +35,22 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -36,43 +35,22 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
const
auto
N
=
out_n_k_ho_wo_lengths
[
I0
];
const
auto
N
=
out_n_k0_ho_wo_k1_lengths
[
I0
];
const
auto
K
=
out_n_k_ho_wo_lengths
[
I1
];
const
auto
K0
=
out_n_k0_ho_wo_k1_lengths
[
I1
];
const
auto
C
=
wei_k_c_y_x_lengths
[
I1
];
const
auto
Ho
=
out_n_k0_ho_wo_k1_lengths
[
I2
];
const
auto
Wo
=
out_n_k0_ho_wo_k1_lengths
[
I3
];
const
auto
K1
=
out_n_k0_ho_wo_k1_lengths
[
I4
];
const
auto
Hi
=
in_n_c_hi_wi_lengths
[
I2
];
const
auto
C0
=
in_n_c0_hi_wi_c1_lengths
[
I1
];
const
auto
Wi
=
in_n_c_hi_wi_lengths
[
I3
];
const
auto
Hi
=
in_n_c0_hi_wi_c1_lengths
[
I2
];
const
auto
Wi
=
in_n_c0_hi_wi_c1_lengths
[
I3
];
const
auto
C1
=
in_n_c0_hi_wi_c1_lengths
[
I4
];
const
auto
Ho
=
out_n_k_ho_wo_lengths
[
I2
];
const
auto
K
=
wei_k_c0_y_x_c1_lengths
[
I0
];
const
auto
Wo
=
out_n_k_ho_wo_lengths
[
I3
];
const
auto
Y
=
wei_k_c0_y_x_c1_lengths
[
I2
];
const
auto
X
=
wei_k_c0_y_x_c1_lengths
[
I3
];
const
auto
Y
=
wei_k_c_y_x_lengths
[
I2
];
const
auto
X
=
wei_k_c_y_x_lengths
[
I3
];
const
auto
C0
=
C
/
Number
<
InWeiVectorSize
>
{};
const
auto
C1
=
Number
<
InWeiVectorSize
>
{};
const
auto
K0
=
K
/
Number
<
InWeiVectorSize
>
{};
const
auto
K1
=
Number
<
InWeiVectorSize
>
{};
Tensor
<
TInWei
>
in_n_c0_hi_wi_c1
(
HostTensorDescriptor
(
std
::
initializer_list
<
index_t
>
{
N
,
C0
,
Hi
,
Wi
,
C1
}));
Tensor
<
TInWei
>
wei_k_c0_y_x_c1
(
HostTensorDescriptor
(
std
::
initializer_list
<
index_t
>
{
K
,
C0
,
Y
,
X
,
C1
}));
Tensor
<
TOut
>
out_n_k0_ho_wo_k1
(
HostTensorDescriptor
(
std
::
initializer_list
<
index_t
>
{
N
,
K0
,
Ho
,
Wo
,
K1
}));
auto
f_nchw2nc0hwc1
=
[
&
](
auto
n
,
auto
hi
,
auto
wi
,
auto
c
)
{
in_n_c0_hi_wi_c1
(
n
,
c
/
C1
,
hi
,
wi
,
c
%
C1
)
=
in_n_c_hi_wi
(
n
,
c
,
hi
,
wi
);
};
auto
f_kcyx2kc0yxc1
=
[
&
](
auto
k
,
auto
y
,
auto
x
,
auto
c
)
{
wei_k_c0_y_x_c1
(
k
,
c
/
C1
,
y
,
x
,
c
%
C1
)
=
wei_k_c_y_x
(
k
,
c
,
y
,
x
);
};
make_ParallelTensorFunctor
(
f_nchw2nc0hwc1
,
N
,
Hi
,
Wi
,
C
)();
make_ParallelTensorFunctor
(
f_kcyx2kc0yxc1
,
K
,
Y
,
X
,
C
)();
DeviceMem
in_n_c0_hi_wi_c1_device_buf
(
sizeof
(
TInWei
)
*
DeviceMem
in_n_c0_hi_wi_c1_device_buf
(
sizeof
(
TInWei
)
*
in_n_c0_hi_wi_c1
.
mDesc
.
GetElementSpace
());
in_n_c0_hi_wi_c1
.
mDesc
.
GetElementSpace
());
...
@@ -83,13 +61,6 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -83,13 +61,6 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
in_n_c0_hi_wi_c1_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
in_n_c0_hi_wi_c1_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
wei_k_c0_y_x_c1_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
wei_k_c0_y_x_c1_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
const
auto
in_n_c0_hi_wi_c1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
C0
,
Hi
,
Wi
,
I1
));
const
auto
wei_k_c0_y_x_c1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C0
,
Y
,
X
,
I1
));
const
auto
out_n_k0_ho_wo_k1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
#if 0
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
...
@@ -144,8 +115,17 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -144,8 +115,17 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr
index_t
CThreadTransferDstScalarPerVector_K
=
K1
;
constexpr
index_t
CThreadTransferDstScalarPerVector_K
=
K1
;
#endif
#endif
constexpr
index_t
InWeiVectorSize
=
C1
;
const
auto
in_n_c0_hi_wi_c1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
C0
,
Hi
,
Wi
,
Number
<
C1
/
InWeiVectorSize
>
{}));
const
auto
wei_k_c0_y_x_c1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C0
,
Y
,
X
,
Number
<
C1
/
InWeiVectorSize
>
{}));
const
auto
out_n_k0_ho_wo_k1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
<
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc
0
hw
c1
_kc
0
yx
c1
_nk
0
hw
k1
_outpad
<
BlockSize
,
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TAcc
,
TAcc
,
...
@@ -188,7 +168,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -188,7 +168,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
nrepeat
);
nrepeat
);
{
{
float
perf
=
static_cast
<
float
>
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
float
perf
=
static_cast
<
float
>
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
0
*
C1
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
...
@@ -197,10 +177,4 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -197,10 +177,4 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
}
}
out_n_k0_ho_wo_k1_device_buf
.
FromDevice
(
out_n_k0_ho_wo_k1
.
mData
.
data
());
out_n_k0_ho_wo_k1_device_buf
.
FromDevice
(
out_n_k0_ho_wo_k1
.
mData
.
data
());
auto
f_nk0hwk1_to_nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
out_n_k_ho_wo
(
n
,
k
,
ho
,
wo
)
=
out_n_k0_ho_wo_k1
(
n
,
k
/
K1
,
ho
,
wo
,
k
%
K1
);
};
make_ParallelTensorFunctor
(
f_nk0hwk1_to_nkhw
,
N
,
K
,
Ho
,
Wo
)();
}
}
host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp
View file @
5ce317cb
...
@@ -4,16 +4,8 @@
...
@@ -4,16 +4,8 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
typename
AccType
,
void
device_gemm_xdlops_km_kn_mn
(
const
Tensor
<
ABType
>&
a_k_m
,
typename
CType
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
>
void
device_gemm_xdlops_km_kn_mn
(
const
ADesc
&
a_k_m_grid_desc
,
const
BDesc
&
b_k_n_grid_desc
,
const
CDesc
&
c_m_n_grid_desc
,
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_k_n
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_m_n
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
std
::
cout
<<
__func__
<<
std
::
endl
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
...
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
...
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
MPerBlock
=
256
;
...
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
...
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
#endif
const
auto
K
=
a_k_m
_grid_d
esc
.
GetLength
(
I0
)
;
const
auto
K
=
a_k_m
.
mD
esc
.
GetLength
s
()[
0
]
;
const
auto
M
=
a_k_m
_grid_d
esc
.
GetLength
(
I1
)
;
const
auto
M
=
a_k_m
.
mD
esc
.
GetLength
s
()[
1
]
;
const
auto
N
=
b_k_n
_grid_d
esc
.
GetLength
(
I1
)
;
const
auto
N
=
b_k_n
.
mD
esc
.
GetLength
s
()[
1
]
;
constexpr
auto
K1Number
=
Number
<
K1
>
{};
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k_m_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
make_pass_through_transform
(
M
)),
a_k_m
.
mDesc
.
GetStrides
()[
1
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
const
auto
b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_k_n_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
make_pass_through_transform
(
N
)),
b_k_n
.
mDesc
.
GetStrides
()[
1
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
Sequence
<
0
>
{},
// 1+: N
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{},
// 1+: N
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
Sequence
<
0
>
{},
// 1-: N
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{},
// 1-: N
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
...
...
host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp
0 → 100644
View file @
5ce317cb
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_km_kn_nm
(
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_k_m
.
mDesc
.
GetLengths
()[
0
];
const
auto
M
=
a_k_m
.
mDesc
.
GetLengths
()[
1
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
a_k_m
.
mDesc
.
GetStrides
()[
1
],
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
BBlockTransferSrcScalarPerVector_N
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
View file @
5ce317cb
...
@@ -4,16 +4,8 @@
...
@@ -4,16 +4,8 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
typename
AccType
,
void
device_gemm_xdlops_km_nk_mn
(
const
Tensor
<
ABType
>&
a_k_m
,
typename
CType
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
>
void
device_gemm_xdlops_km_nk_mn
(
const
ADesc
&
a_k_m_grid_desc
,
const
BDesc
&
b_n_k_grid_desc
,
const
CDesc
&
c_m_n_grid_desc
,
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_n_k
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_m_n
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
std
::
cout
<<
__func__
<<
std
::
endl
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
...
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
...
@@ -60,9 +49,121 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
MPerBlock
=
256
;
...
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
...
@@ -88,46 +189,185 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
#endif
const
auto
K
=
a_k_m
_grid_d
esc
.
GetLength
(
I0
)
;
const
auto
K
=
a_k_m
.
mD
esc
.
GetLength
s
()[
0
]
;
const
auto
M
=
a_k_m
_grid_d
esc
.
GetLength
(
I1
)
;
const
auto
M
=
a_k_m
.
mD
esc
.
GetLength
s
()[
1
]
;
const
auto
N
=
b_n_k
_grid_d
esc
.
GetLength
(
I0
)
;
const
auto
N
=
b_n_k
.
mD
esc
.
GetLength
s
()[
0
]
;
constexpr
auto
K1Number
=
Number
<
K1
>
{};
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k_m_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
make_pass_through_transform
(
M
)),
a_k_m
.
mDesc
.
GetStrides
()[
1
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
const
auto
b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_n_k_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
b_n_k
.
mDesc
.
GetStrides
()[
0
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
Sequence
<
0
>
{},
// 1+: N
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{},
// 1+: N
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
Sequence
<
0
>
{},
// 1-: N
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{},
// 1-: N
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
...
...
host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp
0 → 100644
View file @
5ce317cb
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_km_nk_nm
(
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
2
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_k_m
.
mDesc
.
GetLengths
()[
0
];
const
auto
M
=
a_k_m
.
mDesc
.
GetLengths
()[
1
];
const
auto
N
=
b_n_k
.
mDesc
.
GetLengths
()[
0
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_k_m
.
mDesc
.
GetStrides
()[
0
],
a_k_m
.
mDesc
.
GetStrides
()[
1
],
a_k_m
.
mDesc
.
GetStrides
()[
0
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
View file @
5ce317cb
...
@@ -4,16 +4,8 @@
...
@@ -4,16 +4,8 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
typename
AccType
,
void
device_gemm_xdlops_mk_kn_mn
(
const
Tensor
<
ABType
>&
a_m_k
,
typename
CType
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
>
void
device_gemm_xdlops_mk_kn_mn
(
const
ADesc
&
a_m_k_grid_desc
,
const
BDesc
&
b_k_n_grid_desc
,
const
CDesc
&
c_m_n_grid_desc
,
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_k_n
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_m_n
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
std
::
cout
<<
__func__
<<
std
::
endl
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
...
@@ -33,8 +22,148 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
...
@@ -33,8 +22,148 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 1
#if 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
MPerBlock
=
256
;
...
@@ -88,46 +217,157 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
...
@@ -88,46 +217,157 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
#endif
const
auto
K
=
a_m_k
_grid_d
esc
.
GetLength
(
I1
)
;
const
auto
K
=
a_m_k
.
mD
esc
.
GetLength
s
()[
1
]
;
const
auto
M
=
a_m_k
_grid_d
esc
.
GetLength
(
I0
)
;
const
auto
M
=
a_m_k
.
mD
esc
.
GetLength
s
()[
0
]
;
const
auto
N
=
b_k_n
_grid_d
esc
.
GetLength
(
I1
)
;
const
auto
N
=
b_k_n
.
mD
esc
.
GetLength
s
()[
1
]
;
constexpr
auto
K1Number
=
Number
<
K1
>
{};
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_m_k_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
make_pass_through_transform
(
M
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
a_m_k
.
mDesc
.
GetStrides
()[
0
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
const
auto
b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_k_n_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
make_pass_through_transform
(
N
)),
b_k_n
.
mDesc
.
GetStrides
()[
1
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
Sequence
<
0
>
{},
// 1+: N
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{},
// 1+: N
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
Sequence
<
0
>
{},
// 1-: N
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{},
// 1-: N
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
...
@@ -147,9 +387,9 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
...
@@ -194,13 +434,17 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
...
...
host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp
0 → 100644
View file @
5ce317cb
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_mk_kn_nm
(
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_k_n
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
ABType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_N = 2;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_k_n
.
mDesc
.
GetLengths
()[
1
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_k_n
.
mDesc
.
GetStrides
()[
0
],
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
BBlockTransferSrcScalarPerVector_N
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
View file @
5ce317cb
...
@@ -4,16 +4,8 @@
...
@@ -4,16 +4,8 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
typename
AccType
,
void
device_gemm_xdlops_mk_nk_mn
(
const
Tensor
<
ABType
>&
a_m_k
,
typename
CType
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
>
void
device_gemm_xdlops_mk_nk_mn
(
const
ADesc
&
a_m_k_grid_desc
,
const
BDesc
&
b_n_k_grid_desc
,
const
CDesc
&
c_m_n_grid_desc
,
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_n_k
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_m_n
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
std
::
cout
<<
__func__
<<
std
::
endl
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
...
@@ -34,6 +23,34 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -34,6 +23,34 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -60,9 +77,93 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -60,9 +77,93 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
MPerBlock
=
256
;
...
@@ -90,7 +191,7 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -90,7 +191,7 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
// [M, N, K0, K1] = [128, 256, 4, 8]
, C = 128,
for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
MPerBlock
=
128
;
...
@@ -117,8 +218,36 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -117,8 +218,36 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
MPerBlock
=
128
;
...
@@ -144,46 +273,131 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -144,46 +273,131 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
#endif
const
auto
K
=
a_m_k
_grid_d
esc
.
GetLength
(
I1
)
;
const
auto
K
=
a_m_k
.
mD
esc
.
GetLength
s
()[
1
]
;
const
auto
M
=
a_m_k
_grid_d
esc
.
GetLength
(
I0
)
;
const
auto
M
=
a_m_k
.
mD
esc
.
GetLength
s
()[
0
]
;
const
auto
N
=
b_n_k
_grid_d
esc
.
GetLength
(
I0
)
;
const
auto
N
=
b_n_k
.
mD
esc
.
GetLength
s
()[
0
]
;
constexpr
auto
K1Number
=
Number
<
K1
>
{};
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
K0
=
K
/
K1Number
;
#if 1
// non-padded GEMM
const
auto
a_k0_m_k1_grid_desc
=
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_m_k_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
make_pass_through_transform
(
M
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
a_m_k
.
mDesc
.
GetStrides
()[
0
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
const
auto
b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_n_k_grid_desc
,
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
b_n_k
.
mDesc
.
GetStrides
()[
0
],
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
Sequence
<
0
>
{},
// 1+: N
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
>
{}),
// 2+: K1
Sequence
<
0
,
0
,
0
>
{},
// 1+: N
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
Sequence
<
0
>
{},
// 1-: N
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{},
// 1-: N
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
@@ -203,9 +417,80 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -203,9 +417,80 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
#else
// padded GEMM
const
auto
a_k0_m_k1_grid_desc_tmp
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
MRightPad
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
)
*
MPerBlock
-
M
;
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k0_m_k1_grid_desc_tmp
,
make_tuple
(
make_pass_through_transform
(
K0
),
make_right_pad_transform
(
M
,
MRightPad
),
make_pass_through_transform
(
K1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc_tmp
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
transform_tensor_descriptor
(
c_m_n_grid_desc_tmp
,
make_tuple
(
make_right_pad_transform
(
M
,
MRightPad
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
#endif
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
...
@@ -250,13 +535,17 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
...
@@ -250,13 +535,17 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
,
// CAccessOrderMRepeatNRepeat
true
,
// ABlockLdsExtraM
true
// BBlockLdsExtraN
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
c_m_n_grid_desc
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
,
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
,
a_k0_m_k1_grid_step_hacks
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
...
...
host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp
0 → 100644
View file @
5ce317cb
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
>
void
device_gemm_xdlops_mk_nk_nm
(
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_n_m
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_n_m_device_buf
(
sizeof
(
CType
)
*
c_n_m
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_n_m_device_buf
.
ToDevice
(
c_n_m
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 4;
#elif
0
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
4
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
4
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
4
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
4
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
64
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
4
;
#endif
const
auto
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
const
auto
M
=
a_m_k
.
mDesc
.
GetLengths
()[
0
];
const
auto
N
=
b_n_k
.
mDesc
.
GetLengths
()[
0
];
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
M
,
K1Number
),
make_tuple
(
K1Number
*
a_m_k
.
mDesc
.
GetStrides
()[
1
],
a_m_k
.
mDesc
.
GetStrides
()[
0
],
a_m_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
b_k0_n_k1_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N
,
K1Number
),
make_tuple
(
K1Number
*
b_n_k
.
mDesc
.
GetStrides
()[
1
],
b_n_k
.
mDesc
.
GetStrides
()[
0
],
b_n_k
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_n_m
.
mDesc
.
GetStrides
()[
1
],
c_n_m
.
mDesc
.
GetStrides
()[
0
]));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: M
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: M
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
Sequence
<
0
>
{},
// 1+: N
Sequence
<
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
>
{},
// 0-: K0
Sequence
<
0
>
{},
// 1-: N
Sequence
<
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_n_m_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_n_m_device_buf
.
FromDevice
(
c_n_m
.
mData
.
data
());
}
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc
0
hw
c1
_kc
0
yx
c1
_nk
0
hw
k1
.hpp
View file @
5ce317cb
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_HPP
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC
0
HW
c1
_KC
0
YX
C1
_NK
0
HW
K1
_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC
0
HW
c1
_KC
0
YX
C1
_NK
0
HW
K1
_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
@@ -28,7 +28,7 @@ template <ck::index_t BlockSize,
...
@@ -28,7 +28,7 @@ template <ck::index_t BlockSize,
ck
::
index_t
BThreadTransferSrcScalarPerVector_E2
,
ck
::
index_t
BThreadTransferSrcScalarPerVector_E2
,
ck
::
index_t
CThreadTransferDstScalarPerVector_K
,
ck
::
index_t
CThreadTransferDstScalarPerVector_K
,
ck
::
index_t
activ_type
>
ck
::
index_t
activ_type
>
struct
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
struct
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc
0
hw
c1
_kc
0
yx
c1
_nk
0
hw
k1
_outpad
{
{
template
<
typename
...
Wei
,
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
5ce317cb
#ifndef DRIVER_GEMM_XDLOPS_V2R3
#ifndef DRIVER_GEMM_XDLOPS_V2R3
_HPP
#define DRIVER_GEMM_XDLOPS_V2R3
#define DRIVER_GEMM_XDLOPS_V2R3
_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
@@ -46,13 +46,17 @@ template <ck::index_t BlockSize,
...
@@ -46,13 +46,17 @@ template <ck::index_t BlockSize,
typename
CGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
>
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
__host__
float
driver_gemm_xdlops_v2r3
(
const
FloatAB
*
p_a_grid
,
__host__
float
driver_gemm_xdlops_v2r3
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
FloatC
*
p_c_grid
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
AGridStepHacks
,
AGridStepHacks
,
BGridStepHacks
,
BGridStepHacks
,
CGridStepHacks
,
CGridStepHacks
,
...
@@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
CGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
>
;
CAccessOrderMRepeatNRepeat
,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
{
{
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
...
@@ -123,7 +129,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -123,7 +129,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
M01
,
N01
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
...
@@ -134,7 +141,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -134,7 +141,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
);
const
auto
c_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_m_n_grid_desc
,
M01
,
N01
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
...
...
host/driver_offline/src/conv_bwd_driver_offline.cpp
View file @
5ce317cb
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include <half.hpp>
#include "config.hpp"
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -14,15 +15,16 @@
...
@@ -14,15 +15,16 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp"
#define USE_MODE 1
#define USE_MODE 1
#define USE_CONV_BWD_V4R1_XDL_NHWC
1
#define USE_CONV_BWD_V4R1_XDL_NHWC
0
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
enum
ConvBackwardDataAlgo
enum
ConvBackwardDataAlgo
{
{
V4R1XDLNHWC
,
V4R1XDLNHWC
,
// 0
V4R1R2XDLNHWC
,
V4R1R2XDLNHWC
,
// 1
};
};
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
@@ -280,20 +282,43 @@ int main(int argc, char* argv[])
...
@@ -280,20 +282,43 @@ int main(int argc, char* argv[])
const
auto
tmp
=
f_make_for_device_nhwc
();
const
auto
tmp
=
f_make_for_device_nhwc
();
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
<
in_data_t
,
if
(
Y
==
1
&&
X
==
1
&&
in_left_pad_h
==
0
&&
in_left_pad_w
==
0
&&
in_right_pad_h
==
0
&&
acc_data_t
,
in_right_pad_w
==
0
)
out_data_t
>
(
{
tmp
[
I0
],
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1
<
tmp
[
I1
],
in_data_t
,
tmp
[
I2
],
acc_data_t
,
tmp
[
I3
],
out_data_t
>
(
tmp
[
I0
],
tmp
[
I4
],
tmp
[
I1
],
tmp
[
I5
],
tmp
[
I2
],
tmp
[
I6
],
tmp
[
I3
],
in_device
,
tmp
[
I4
],
wei
,
tmp
[
I5
],
out
,
tmp
[
I6
],
nrepeat
);
in_device
,
wei
,
out
,
nrepeat
);
}
else
{
#if 1
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
<
in_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in_device
,
wei
,
out
,
nrepeat
);
#endif
}
}
}
#endif
#endif
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
5ce317cb
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include <half.hpp>
#include "config.hpp"
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -15,15 +16,13 @@
...
@@ -15,15 +16,13 @@
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V6R1_NCHW 1
#define USE_CONV_FWD_V5R1_NCHWC 1
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
...
@@ -32,9 +31,8 @@ enum ConvForwardAlgo
...
@@ -32,9 +31,8 @@ enum ConvForwardAlgo
V4R4NCHW
,
// 0
V4R4NCHW
,
// 0
V4R4R2NHWC
,
// 1
V4R4R2NHWC
,
// 1
V6R1NCHW
,
// 2
V6R1NCHW
,
// 2
V5R1NCHWC
,
// 3
V4R4R2XDLNCHW
,
// 3
V4R4R2XDLNCHW
,
// 4
V4R4R4XDLNHWC
// 4
V4R4R4XDLNHWC
// 5
};
};
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
@@ -49,8 +47,6 @@ int main(int argc, char* argv[])
...
@@ -49,8 +47,6 @@ int main(int argc, char* argv[])
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
index_t
activ_type
=
0
;
#if USE_DYNAMIC_MODE
#if USE_DYNAMIC_MODE
// dynamic mode
// dynamic mode
if
(
argc
!=
22
)
if
(
argc
!=
22
)
...
@@ -104,55 +100,13 @@ int main(int argc, char* argv[])
...
@@ -104,55 +100,13 @@ int main(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
#if 1
constexpr
auto
N
=
Number
<
128
>
{};
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
192
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
71
>
{};
constexpr
auto
Hi
=
Number
<
1080
>
{};
constexpr
auto
Wi
=
Number
<
71
>
{};
constexpr
auto
Wi
=
Number
<
1920
>
{};
constexpr
auto
K
=
Number
<
256
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
270
>
{};
constexpr
auto
Wi
=
Number
<
480
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
135
>
{};
constexpr
auto
Wi
=
Number
<
240
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
1440
>
{};
constexpr
auto
Wi
=
Number
<
2560
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
2160
>
{};
constexpr
auto
Wi
=
Number
<
3840
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#endif
constexpr
auto
conv_stride_h
=
I1
;
constexpr
auto
conv_stride_h
=
I1
;
constexpr
auto
conv_stride_w
=
I1
;
constexpr
auto
conv_stride_w
=
I1
;
...
@@ -170,7 +124,7 @@ int main(int argc, char* argv[])
...
@@ -170,7 +124,7 @@ int main(int argc, char* argv[])
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
#endif
#endif
#if
0
#if
1
using
in_data_t
=
float
;
using
in_data_t
=
float
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
out_data_t
=
float
;
using
out_data_t
=
float
;
...
@@ -385,34 +339,6 @@ int main(int argc, char* argv[])
...
@@ -385,34 +339,6 @@ int main(int argc, char* argv[])
}
}
#endif
#endif
#if USE_CONV_FWD_V5R1_NCHWC
if
(
algo
==
ConvForwardAlgo
::
V5R1NCHWC
)
{
if
(
layout
!=
ConvTensorLayout
::
NCHW
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
tmp
=
f_make_for_device_nchw
();
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw
<
in_data_t
,
acc_data_t
,
out_data_t
,
8
,
activ_type
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei
,
out_device
,
nrepeat
);
}
#endif
#if USE_CONV_FWD_V4R4R2_XDL_NCHW
#if USE_CONV_FWD_V4R4R2_XDL_NCHW
if
(
algo
==
ConvForwardAlgo
::
V4R4R2XDLNCHW
)
if
(
algo
==
ConvForwardAlgo
::
V4R4R2XDLNCHW
)
{
{
...
@@ -476,8 +402,7 @@ int main(int argc, char* argv[])
...
@@ -476,8 +402,7 @@ int main(int argc, char* argv[])
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
layout
,
layout
);
activ_type
);
check_error
(
out_host
,
out_device
);
check_error
(
out_host
,
out_device
);
...
...
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
0 → 100644
View file @
5ce317cb
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V5R1_NCHWC 1
enum
ConvForwardAlgo
{
V5R1NCHWC
// 0
};
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
#if USE_DYNAMIC_MODE
// dynamic mode
if
(
argc
!=
21
)
{
printf
(
"arg1 to 5: algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
\n
"
);
exit
(
1
);
}
const
ConvForwardAlgo
algo
=
static_cast
<
ConvForwardAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
index_t
N
=
std
::
stoi
(
argv
[
7
]);
const
index_t
K
=
std
::
stoi
(
argv
[
8
]);
const
index_t
C
=
std
::
stoi
(
argv
[
9
]);
const
index_t
Y
=
std
::
stoi
(
argv
[
10
]);
const
index_t
X
=
std
::
stoi
(
argv
[
11
]);
const
index_t
Hi
=
std
::
stoi
(
argv
[
12
]);
const
index_t
Wi
=
std
::
stoi
(
argv
[
13
]);
const
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
14
]);
const
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
15
]);
const
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
16
]);
const
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
17
]);
const
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
18
]);
const
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
19
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
21
]);
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
#else
// static mode
if
(
argc
<
6
)
{
printf
(
"arg1 to 5: algo, do_verification, init_method, do_log, nrepeat
\n
"
);
exit
(
1
);
}
const
ConvForwardAlgo
algo
=
static_cast
<
ConvForwardAlgo
>
(
std
::
stoi
(
argv
[
1
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
2
]);
const
int
init_method
=
std
::
stoi
(
argv
[
3
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
4
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
5
]);
#if 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
Hi
=
Number
<
1080
>
{};
constexpr
auto
Wi
=
Number
<
1920
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
270
>
{};
constexpr
auto
Wi
=
Number
<
480
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
135
>
{};
constexpr
auto
Wi
=
Number
<
240
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
1440
>
{};
constexpr
auto
Wi
=
Number
<
2560
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
C
=
Number
<
16
>
{};
constexpr
auto
Hi
=
Number
<
2160
>
{};
constexpr
auto
Wi
=
Number
<
3840
>
{};
constexpr
auto
K
=
Number
<
64
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
#endif
constexpr
auto
conv_stride_h
=
I1
;
constexpr
auto
conv_stride_w
=
I1
;
constexpr
auto
conv_dilation_h
=
I1
;
constexpr
auto
conv_dilation_w
=
I1
;
constexpr
auto
in_left_pad_h
=
I1
;
constexpr
auto
in_left_pad_w
=
I1
;
constexpr
auto
in_right_pad_h
=
I1
;
constexpr
auto
in_right_pad_w
=
I1
;
constexpr
auto
YEff
=
(
Y
-
I1
)
*
conv_dilation_h
+
I1
;
constexpr
auto
XEff
=
(
X
-
I1
)
*
conv_dilation_w
+
I1
;
constexpr
auto
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
I1
;
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
#endif
#if 0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif
1
using
in_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
#elif 1
using
in_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
out_data_t
=
int8_t
;
#endif
std
::
vector
<
std
::
size_t
>
in_lengths_host
(
5
),
wei_lengths_host
(
5
),
out_lengths_host
(
5
);
in_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
in_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C0
);
in_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Hi
);
in_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wi
);
in_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
C1
);
wei_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K0
*
K1
);
wei_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C0
);
wei_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Y
);
wei_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
X
);
wei_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
C1
);
out_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
out_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K0
);
out_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Ho
);
out_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wo
);
out_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
K1
);
Tensor
<
in_data_t
>
in
(
in_lengths_host
);
Tensor
<
in_data_t
>
wei
(
wei_lengths_host
);
Tensor
<
out_data_t
>
out_host
(
out_lengths_host
);
Tensor
<
out_data_t
>
out_device
(
out_lengths_host
);
ostream_HostTensorDescriptor
(
in
.
mDesc
,
std
::
cout
<<
"in: "
);
ostream_HostTensorDescriptor
(
wei
.
mDesc
,
std
::
cout
<<
"wei: "
);
ostream_HostTensorDescriptor
(
out_host
.
mDesc
,
std
::
cout
<<
"out: "
);
print_array
(
"InLeftPads"
,
make_tuple
(
in_left_pad_h
,
in_left_pad_w
));
print_array
(
"InRightPads"
,
make_tuple
(
in_right_pad_h
,
in_right_pad_w
));
print_array
(
"ConvStrides"
,
make_tuple
(
conv_stride_h
,
conv_stride_w
));
print_array
(
"ConvDilations"
,
make_tuple
(
conv_dilation_h
,
conv_dilation_w
));
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
// no initialization
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
2
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
3
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
4
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
5
:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
0.0
,
1.0
},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.5
,
0.5
},
num_thread
);
break
;
default:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
1
,
5
},
num_thread
);
auto
gen_wei
=
[](
auto
...
is
)
{
return
GeneratorTensor_2
{
1
,
5
}(
is
...)
*
GeneratorTensor_Checkboard
{}(
is
...);
};
wei
.
GenerateTensorValue
(
gen_wei
,
num_thread
);
}
auto
f_make_for_device_nchwc
=
[
&
]()
{
const
auto
in_lengths_dev
=
make_tuple
(
N
,
C0
,
Hi
,
Wi
,
C1
);
const
auto
wei_lengths_dev
=
make_tuple
(
K0
*
K1
,
C0
,
Y
,
X
,
C1
);
const
auto
out_lengths_dev
=
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
);
const
auto
conv_strides_dev
=
make_tuple
(
conv_stride_h
,
conv_stride_w
);
const
auto
conv_dilations_dev
=
make_tuple
(
conv_dilation_h
,
conv_dilation_w
);
const
auto
in_left_pads_dev
=
make_tuple
(
in_left_pad_h
,
in_left_pad_w
);
const
auto
in_right_pads_dev
=
make_tuple
(
in_right_pad_h
,
in_right_pad_w
);
return
make_tuple
(
in_lengths_dev
,
wei_lengths_dev
,
out_lengths_dev
,
conv_strides_dev
,
conv_dilations_dev
,
in_left_pads_dev
,
in_right_pads_dev
);
};
constexpr
index_t
activ_type
=
0
;
#if USE_CONV_FWD_V5R1_NCHWC
if
(
algo
==
ConvForwardAlgo
::
V5R1NCHWC
)
{
const
auto
tmp
=
f_make_for_device_nchwc
();
device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
<
in_data_t
,
acc_data_t
,
out_data_t
,
activ_type
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei
,
out_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_direct_convolution_nchwc
(
in
,
wei
,
out_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
activ_type
);
check_error
(
out_host
,
out_device
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei: "
,
wei
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_host : "
,
out_host
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_device: "
,
out_device
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
5ce317cb
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include <half.hpp>
#include "config.hpp"
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
...
@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
#endif
#endif
#if
1
#if
0
using in_data_t = float;
using in_data_t = float;
using acc_data_t = float;
using acc_data_t = float;
using out_data_t = float;
using out_data_t = float;
...
...
host/driver_offline/src/gemm_driver_offline.cpp
View file @
5ce317cb
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include <half.hpp>
#include "config.hpp"
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -16,11 +17,19 @@
...
@@ -16,11 +17,19 @@
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#include "device_gemm_xdlops_mk_kn_nm.hpp"
#include "device_gemm_xdlops_mk_nk_nm.hpp"
#include "device_gemm_xdlops_km_kn_nm.hpp"
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
#define USE_GEMM_XDL_KM_NK_NM 0
enum
GemmAlgo
enum
GemmAlgo
{
{
...
@@ -28,21 +37,21 @@ enum GemmAlgo
...
@@ -28,21 +37,21 @@ enum GemmAlgo
Xdl_MK_NK_MN
,
// 1
Xdl_MK_NK_MN
,
// 1
Xdl_KM_KN_MN
,
// 2
Xdl_KM_KN_MN
,
// 2
Xdl_KM_NK_MN
,
// 3
Xdl_KM_NK_MN
,
// 3
Xdl_MK_KN_NM
,
// 4
Xdl_MK_NK_NM
,
// 5
Xdl_KM_KN_NM
,
// 6
Xdl_KM_NK_NM
,
// 7
};
};
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
using
namespace
ck
;
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
if
(
argc
!=
12
)
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
// dynamic mode
if
(
argc
!=
10
)
{
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: M, N, K
\n
"
);
printf
(
"rest: M, N, K
\n
"
);
printf
(
"debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01
\n
"
);
exit
(
1
);
exit
(
1
);
}
}
...
@@ -57,6 +66,9 @@ int main(int argc, char* argv[])
...
@@ -57,6 +66,9 @@ int main(int argc, char* argv[])
const
index_t
N
=
std
::
stoi
(
argv
[
8
]);
const
index_t
N
=
std
::
stoi
(
argv
[
8
]);
const
index_t
K
=
std
::
stoi
(
argv
[
9
]);
const
index_t
K
=
std
::
stoi
(
argv
[
9
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
=
std
::
stoi
(
argv
[
10
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
=
std
::
stoi
(
argv
[
11
]);
#if 0
#if 0
using ab_data_t = float;
using ab_data_t = float;
using acc_data_t = float;
using acc_data_t = float;
...
@@ -74,69 +86,44 @@ int main(int argc, char* argv[])
...
@@ -74,69 +86,44 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
a_lengths_host
(
2
),
b_lengths_host
(
2
),
c_lengths_host
(
2
);
std
::
vector
<
std
::
size_t
>
a_lengths_host
(
2
),
b_lengths_host
(
2
),
c_lengths_host
(
2
);
std
::
vector
<
std
::
size_t
>
a_strides_host
(
2
),
b_strides_host
(
2
),
c_strides_host
(
2
);
std
::
vector
<
std
::
size_t
>
a_strides_host
(
2
),
b_strides_host
(
2
),
c_strides_host
(
2
);
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
// A
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
MK_KN_NM
||
layout
==
GemmMatrixLayout
::
MK_NK_NM
)
{
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
}
else
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
{
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
// B
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
KM_NK_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_NM
||
layout
==
GemmMatrixLayout
::
KM_NK_NM
)
{
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
{
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
// C
if
(
layout
==
GemmMatrixLayout
::
MK_KN_MN
||
layout
==
GemmMatrixLayout
::
KM_KN_MN
||
layout
==
GemmMatrixLayout
::
MK_NK_MN
||
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
...
@@ -144,7 +131,10 @@ int main(int argc, char* argv[])
...
@@ -144,7 +131,10 @@ int main(int argc, char* argv[])
}
}
else
else
{
{
std
::
runtime_error
(
"wrong! not implemented"
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
}
Tensor
<
ab_data_t
>
a
(
a_lengths_host
,
a_strides_host
);
Tensor
<
ab_data_t
>
a
(
a_lengths_host
,
a_strides_host
);
...
@@ -185,38 +175,6 @@ int main(int argc, char* argv[])
...
@@ -185,38 +175,6 @@ int main(int argc, char* argv[])
b
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.5
,
0.5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.5
,
0.5
},
num_thread
);
}
}
auto
f_make_for_device_mk_kn_mn
=
[
&
]()
{
const
auto
a_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
const
auto
b_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
N
,
I1
));
const
auto
c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
auto
f_make_for_device_mk_nk_mn
=
[
&
]()
{
const
auto
a_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
const
auto
b_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
K
,
I1
));
const
auto
c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
auto
f_make_for_device_km_kn_mn
=
[
&
]()
{
const
auto
a_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K
,
M
),
make_tuple
(
M
,
I1
));
const
auto
b_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
N
,
I1
));
const
auto
c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
auto
f_make_for_device_km_nk_mn
=
[
&
]()
{
const
auto
a_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K
,
M
),
make_tuple
(
M
,
I1
));
const
auto
b_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
K
,
I1
));
const
auto
c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
#if USE_GEMM_XDL_MK_KN_MN
#if USE_GEMM_XDL_MK_KN_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_MN
)
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_MN
)
{
{
...
@@ -225,10 +183,7 @@ int main(int argc, char* argv[])
...
@@ -225,10 +183,7 @@ int main(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! layout"
);
throw
std
::
runtime_error
(
"wrong! layout"
);
}
}
const
auto
descs
=
f_make_for_device_mk_kn_mn
();
device_gemm_xdlops_mk_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
device_gemm_xdlops_mk_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
a
,
b
,
c_device
,
nrepeat
);
}
}
#endif
#endif
...
@@ -240,10 +195,7 @@ int main(int argc, char* argv[])
...
@@ -240,10 +195,7 @@ int main(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! layout"
);
throw
std
::
runtime_error
(
"wrong! layout"
);
}
}
const
auto
descs
=
f_make_for_device_mk_nk_mn
();
device_gemm_xdlops_mk_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
device_gemm_xdlops_mk_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
a
,
b
,
c_device
,
nrepeat
);
}
}
#endif
#endif
...
@@ -255,10 +207,7 @@ int main(int argc, char* argv[])
...
@@ -255,10 +207,7 @@ int main(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! layout"
);
throw
std
::
runtime_error
(
"wrong! layout"
);
}
}
const
auto
descs
=
f_make_for_device_km_kn_mn
();
device_gemm_xdlops_km_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
device_gemm_xdlops_km_kn_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
a
,
b
,
c_device
,
nrepeat
);
}
}
#endif
#endif
...
@@ -270,10 +219,55 @@ int main(int argc, char* argv[])
...
@@ -270,10 +219,55 @@ int main(int argc, char* argv[])
throw
std
::
runtime_error
(
"wrong! layout"
);
throw
std
::
runtime_error
(
"wrong! layout"
);
}
}
const
auto
descs
=
f_make_for_device_km_nk_mn
();
device_gemm_xdlops_km_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_KN_NM
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_KN_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_kn_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_MK_NK_NM
if
(
algo
==
GemmAlgo
::
Xdl_MK_NK_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_NK_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_mk_nk_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_KN_NM
if
(
algo
==
GemmAlgo
::
Xdl_KM_KN_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_KN_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_kn_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
}
#endif
#if USE_GEMM_XDL_KM_NK_NM
if
(
algo
==
GemmAlgo
::
Xdl_KM_NK_NM
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_NK_NM
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
device_gemm_xdlops_km_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
device_gemm_xdlops_km_nk_nm
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
a
,
b
,
c_device
,
nrepeat
);
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
a
,
b
,
c_device
,
nrepeat
);
}
}
#endif
#endif
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment