gaoqiong / composable_kernel / Commits

Commit ed966e7f authored Sep 14, 2021 by Jing Zhang

    clean

parent d3146496

Showing 4 changed files with 435 additions and 784 deletions (+435 / -784)
composable_kernel/include/utility/config.hpp                                                                 +1    -1
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp           +7    -7
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp           +427  -234
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp    +0    -542
composable_kernel/include/utility/config.hpp  (view file @ ed966e7f)

...
@@ -78,7 +78,7 @@
 // experimental implementation
 #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
+#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
 #endif

 #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
...
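The only functional change in config.hpp flips the default for the experimental buffer-load out-of-bounds-check "offset trick" from 0 (disabled) to 1 (enabled). As a hedged illustration only (this is not the CK implementation, and kInvalidElementOffset is a hypothetical name), the general shape of such an offset trick is to fold the validity check into the element offset itself, redirecting invalid accesses to an offset the hardware buffer-load treats as out of range, so the hot loop carries no branch:

    // Illustrative sketch, not the actual config.hpp logic.
    #include <cstdint>

    // sentinel offset beyond any real buffer extent (hypothetical value)
    constexpr std::uint32_t kInvalidElementOffset = 0x80000000u;

    // With the trick enabled, validity is folded into the offset, so the
    // subsequent raw buffer load needs no per-element branch: the hardware
    // returns 0 for the out-of-range sentinel.
    inline std::uint32_t apply_oob_offset_trick(std::uint32_t offset, bool is_valid)
    {
        return is_valid ? offset : kInvalidElementOffset;
    }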
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp  (view file @ ed966e7f)

 #include <unistd.h>
 #include "device.hpp"
 #include "host_tensor.hpp"
-#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
+#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"

 template <typename TInWei,
           typename TAcc,
...
@@ -48,7 +48,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     const auto Y = wei_k_c_y_x_lengths[I2];
     const auto X = wei_k_c_y_x_lengths[I3];

-    constexpr auto InWeiVectorSize = 8;
+    constexpr auto InWeiVectorSize = 4;

 #if 1
     const auto C0 = C / Number<InWeiVectorSize>{};
...
@@ -106,16 +106,16 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     constexpr index_t HoPerBlock = 8;
     constexpr index_t WoPerBlock = 32;

-    constexpr index_t E1        = 2 * 9;
-    constexpr index_t E2        = 8;
-    constexpr index_t EPerBlock = 2;
+    constexpr index_t E1        = 4 * 9;
+    constexpr index_t E2        = C1;
+    constexpr index_t EPerBlock = 4;

     constexpr index_t KPerThread  = KPerBlock;
     constexpr index_t HoPerThread = 2;
     constexpr index_t WoPerThread = 2;
     constexpr index_t EPerThread  = 1;

-    using ABlockTransferThreadSliceLengths_E0_E1_K_E2   = Sequence<1, 9, 1, 8>;
+    using ABlockTransferThreadSliceLengths_E0_E1_K_E2   = Sequence<1, 9, 1, E2>;
     using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>;

     constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
...
@@ -123,7 +123,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;

-    constexpr index_t CThreadTransferDstScalarPerVector_K = 8;
+    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
 #endif

     constexpr auto conv_driver =
...
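Note on the retuned constants: the device wrapper splits the C input channels as C = C0 * C1 with C1 = InWeiVectorSize (now 4 instead of 8), and ties the GEMM vector dimension E2 to C1 rather than a hard-coded 8. A minimal sketch of the implied arithmetic, with assumed example values (C = 64, 3x3 filter) that are not from the commit:

    #include <cassert>

    int main()
    {
        constexpr int C               = 64;     // assumed example channel count
        constexpr int InWeiVectorSize = 4;      // was 8 before this commit
        constexpr int C0              = C / InWeiVectorSize;
        constexpr int C1              = InWeiVectorSize;

        constexpr int Y = 3, X = 3;             // assumed 3x3 filter
        constexpr int E  = C0 * Y * X;          // GEMM reduction length
        constexpr int E1 = 4 * 9;               // was 2 * 9
        constexpr int E2 = C1;                  // vector width now tied to C1

        static_assert(C0 * C1 == C, "channel split must be exact");
        static_assert(E % E1 == 0, "E must factor into E0 * E1");
        constexpr int E0 = E / E1;              // here: 144 / 36 = 4
        assert(E0 >= 1 && E2 == InWeiVectorSize);
        return 0;
    }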
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp  (view file @ ed966e7f)

-#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
-#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
+#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_HPP
+#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_dlops_v2.hpp"
-#include "gridwise_operation_wrapper.hpp"

 template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
+          ck::index_t E1,
+          ck::index_t E2,
           ck::index_t KPerBlock,
           ck::index_t HoPerBlock,
           ck::index_t WoPerBlock,
-          ck::index_t EPerBlock,
+          ck::index_t E1PerBlock,
           ck::index_t KPerThread,
           ck::index_t HoPerThread,
           ck::index_t WoPerThread,
           ck::index_t EPerThread,
-          typename ABlockTransferThreadSliceLengths_E_K,
-          typename ABlockTransferThreadClusterLengths_E_K,
-          ck::index_t ABlockTransferSrcScalarPerVector_E,
-          ck::index_t ABlockTransferDstScalarPerVector_K,
-          ck::index_t BThreadTransferSrcScalarPerVector_W,
-          ck::index_t CThreadTransferDstScalarPerVector_W>
-struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
+          typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
+          typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
+          ck::index_t ABlockTransferSrcScalarPerVector_E2,
+          ck::index_t ABlockTransferDstScalarPerVector_E2,
+          ck::index_t BThreadTransferSrcScalarPerVector_E2,
+          ck::index_t CThreadTransferDstScalarPerVector_K>
+struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
 {
     template <typename... Wei,
               typename... In,
...
@@ -34,16 +35,17 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
               typename ConvDilations,
               typename InLeftPads,
               typename InRightPads>
-    __host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
-                      const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
+    __host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
+                       const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
                       const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                       const ConvStrides& conv_strides,
                       const ConvDilations& conv_dilations,
                       const InLeftPads& in_left_pads,
                       const InRightPads& in_right_pads,
-                      const FloatAB* __restrict__ p_wei_global,
-                      const FloatAB* __restrict__ p_in_global,
-                      FloatC* __restrict__ p_out_global) const
+                      const FloatAB* __restrict__ p_a_grid,
+                      const FloatAB* __restrict__ p_b_grid,
+                      FloatC* __restrict__ p_c_grid,
+                      const int nrepeat) const
     {
         using namespace ck;
...
@@ -53,21 +55,20 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
         constexpr auto I3 = Number<3>{};
         constexpr auto I4 = Number<4>{};

-        const auto N  = in_n_c_hi_wi_global_desc.GetLength(I0);
-        const auto C  = in_n_c_hi_wi_global_desc.GetLength(I1);
-        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
+        const auto N  = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
+        const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
+        const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
+        const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
+        const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);

         const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
         const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
         const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
         const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);

-        const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
-        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
+        const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
+        const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
+        const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);

         const auto ConvStrideH = conv_strides[I0];
         const auto ConvStrideW = conv_strides[I1];
...
@@ -75,85 +76,139 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
         const auto ConvDilationH = conv_dilations[I0];
         const auto ConvDilationW = conv_dilations[I1];

+        const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
+        const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
+
+        const auto OutRightPadH = Hop - Ho;
+        const auto OutRightPadW = Wop - Wo;
+
         const auto InLeftPadH = in_left_pads[I0];
         const auto InLeftPadW = in_left_pads[I1];

-        const auto InRightPadH = in_right_pads[I0];
-        const auto InRightPadW = in_right_pads[I1];
+        const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
+        const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
+
+        std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
+                  << std::endl;
+        std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
+                  << std::endl;
+
+        const auto E = C0 * Y * X;
+
+        static_assert(E2 == C1, "");
+
+        const auto E0 = E / E1;
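The Hop/Wop values round the output extents up to whole block tiles, and the extra right-padding is derived from the difference. The sketch below works one example through the same round-up arithmetic (the values are illustrative, not from the commit):

    #include <iostream>

    int main()
    {
        const int Ho = 28, HoPerBlock = 8;
        const int Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; // 32
        const int OutRightPadH = Hop - Ho;                               // 4
        // The input right pad then grows by OutRightPadH * ConvStrideH rows
        // so the embedded (y, ho) window can read the padded range safely.
        std::cout << "Hop = " << Hop << ", OutRightPadH = " << OutRightPadH << '\n';
        return 0;
    }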
         // weight tensor
-        const auto wei_e_k_global_desc = transform_tensor_descriptor(
-            make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
-            make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<1>{}, Sequence<0>{}));
+        const auto a_e0_k_e2_grid_desc = transform_tensor_descriptor(
+            make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
+            make_tuple(make_pass_through_transform(K),
+                       make_pass_through_transform(C0 * Y * X),
+                       make_pass_through_transform(E2)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
+
+        const auto a_e0_e1_k_e2_grid_desc = transform_tensor_descriptor(
+            a_e0_k_e2_grid_desc,
+            make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
+                       make_pass_through_transform(K),
+                       make_pass_through_transform(E2)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
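For reference, make_unmerge_transform(make_tuple(E0, E1)) splits one linear index into two coordinates. The plain-arithmetic equivalent, written independently of the CK descriptor machinery, is:

    #include <utility>

    // split reduction index e into (e0, e1) with e = e0 * E1 + e1
    constexpr std::pair<int, int> unmerge(int e, int E1)
    {
        return {e / E1, e % E1};
    }

    static_assert(unmerge(75, 36).first == 2 && unmerge(75, 36).second == 3,
                  "75 = 2 * 36 + 3");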
         // input tensor
-        const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
-            in_n_c_hi_wi_global_desc,
-            make_tuple(make_pass_through_transform(N),
-                       make_pass_through_transform(C),
-                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                       make_pad_transform(Wi, InLeftPadW, InRightPadW)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+        const auto in_n_c0_hip_wip_c1_global_desc = transform_tensor_descriptor(
+            in_n_c0_hi_wi_c1_global_desc,
+            make_tuple(make_pass_through_transform(N),
+                       make_pass_through_transform(C0),
+                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
+                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                       make_pass_through_transform(E2)),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

-        const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
-            in_n_c_hip_wip_global_desc,
-            make_tuple(
-                make_pass_through_transform(N),
-                make_pass_through_transform(C),
-                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
-                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
-
-        const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
-            in_n_c_y_ho_x_wo_global_desc,
-            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
-                       make_pass_through_transform(N),
-                       make_pass_through_transform(Ho),
-                       make_pass_through_transform(Wo)),
-            make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+        const auto in_n_c0_y_ho_x_wo_c1_global_desc = transform_tensor_descriptor(
+            in_n_c0_hip_wip_c1_global_desc,
+            make_tuple(
+                make_pass_through_transform(N),
+                make_pass_through_transform(C0),
+                make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
+                make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
+                make_pass_through_transform(E2)),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
+
+        const auto b_e0_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
+            in_n_c0_y_ho_x_wo_c1_global_desc,
+            make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
+                       make_pass_through_transform(N),
+                       make_pass_through_transform(Hop),
+                       make_pass_through_transform(Wop),
+                       make_pass_through_transform(E2)),
+            make_tuple(
+                Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+        const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
+            b_e0_n_ho_wo_e2_grid_desc,
+            make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
+                       make_pass_through_transform(N),
+                       make_pass_through_transform(Hop),
+                       make_pass_through_transform(Wop),
+                       make_pass_through_transform(E2)),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+            make_tuple(
+                Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
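The embed transforms above implement the implicit-GEMM (im2col) indexing: a filter coordinate and an output coordinate are mapped to one padded input coordinate. The same map written as plain arithmetic (illustrative only):

    // make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH))
    // maps (y, ho) to the padded input row hi = y * DilationH + ho * StrideH
    constexpr int embed_hi(int y, int ho, int DilationH, int StrideH)
    {
        return y * DilationH + ho * StrideH;
    }

    static_assert(embed_hi(2, 5, 1, 1) == 7, "3x3 window, unit stride and dilation");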
         // output tensor
-        const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
-            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
-            make_tuple(make_merge_transform(make_tuple(K0, K1)),
-                       make_pass_through_transform(N),
-                       make_pass_through_transform(Ho),
-                       make_pass_through_transform(Wo)),
-            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+        const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
+            make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
+            make_tuple(make_merge_transform(make_tuple(K0, K1)),
+                       make_pass_through_transform(N),
+                       make_pad_transform(Ho, 0, OutRightPadH),
+                       make_pad_transform(Wo, 0, OutRightPadW)),
+            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
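The merge on the first output dimension recombines the split K = K0 * K1 of the NK0HoWoK1 layout into a single k index. As plain arithmetic:

    // make_merge_transform(make_tuple(K0, K1)) exposes k = k0 * K1 + k1
    constexpr int merge_k(int k0, int k1, int K1) { return k0 * K1 + k1; }

    static_assert(merge_k(3, 2, 8) == 26, "k0 = 3, k1 = 2, K1 = 8");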
-        const auto E = C * Y * X;
+        std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;

-        if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
-             (E % EPerBlock) == 0))
+        if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
+             (E1 % E1PerBlock) == 0))
         {
             throw std::runtime_error("wrong! GEMM size no divisible");
         }
         // hack to control index calculation when iterating over a_k_m_global tensor
-        constexpr auto a_e_k_global_step_hacks =
-            make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                       make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-
-        constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
-
-        constexpr auto b_e_n_ho_wo_global_step_hacks = make_tuple(
-            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
-            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
-                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
-
-        constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
-            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
+        constexpr auto a_e0_e1_k_e2_global_step_hacks =
+            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0>{}),
+                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0>{}));
+
+        constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
+            Sequence<0, 0, 0, 0, 0, 0, 0>{};
+
+        constexpr auto b_e0_e1_n_ho_wo_e2_global_step_hacks = make_tuple(
+            make_tuple(
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
+            make_tuple(
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
+
+        constexpr auto b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack =
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};

         // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
         // hack for NKHW format
...
@@ -167,182 +222,320 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
                        Sequence<0, 0, 0, 0, 0>{},
                        Sequence<0, 0, 0, 0, 0>{}));

+#if 1
         // GEMM
-        using gridwise_gemm =
-            GridwiseGemmDlops_km_kn_mn_v3<BlockSize,
-                                          FloatAB,
-                                          FloatAcc,
-                                          FloatC,
-                                          InMemoryDataOperationEnum_t::Set,
-                                          decltype(wei_e_k_global_desc),
-                                          decltype(in_e_n_ho_wo_global_desc),
-                                          decltype(out_k_n_ho_wo_global_desc),
-                                          KPerBlock,
-                                          HoPerBlock,
-                                          WoPerBlock,
-                                          EPerBlock,
-                                          KPerThread,
-                                          HoPerThread,
-                                          WoPerThread,
-                                          EPerThread,
-                                          ABlockTransferThreadSliceLengths_E_K,
-                                          ABlockTransferThreadClusterLengths_E_K,
-                                          Sequence<1, 0>,
-                                          Sequence<1, 0>,
-                                          0,
-                                          ABlockTransferSrcScalarPerVector_E,
-                                          ABlockTransferDstScalarPerVector_K,
-                                          false, // don't move back src coordinate after threadwise copy
-                                          Sequence<0, 2, 3, 1>,
-                                          3,
-                                          BThreadTransferSrcScalarPerVector_W,
-                                          false, // don't move back src coordinate after threadwise copy, which will be fused with
-                                                 // MoveSrcSliceWindow() to save addr computation
-                                          Sequence<0, 2, 3, 1>,
-                                          0,
-                                          CThreadTransferDstScalarPerVector_W,
-                                          decltype(a_e_k_global_step_hacks),
-                                          decltype(b_e_n_ho_wo_global_step_hacks),
-                                          decltype(c_k_n_ho_wo_global_tensor_step_hacks),
-                                          decltype(a_e_k_global_move_slice_window_step_hack),
-                                          decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
+        using GridwiseGemm =
+            GridwiseGemmDlops_km_kn_mn_v3<BlockSize,
+                                          FloatAB,
+                                          FloatAcc,
+                                          FloatC,
+                                          InMemoryDataOperationEnum_t::Set,
+                                          decltype(a_e0_e1_k_e2_grid_desc),
+                                          decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
+                                          decltype(c_k_n_hop_wop_grid_desc),
+                                          E1,
+                                          E2,
+                                          KPerBlock,
+                                          HoPerBlock,
+                                          WoPerBlock,
+                                          E1PerBlock,
+                                          KPerThread,
+                                          HoPerThread,
+                                          WoPerThread,
+                                          EPerThread,
+                                          ABlockTransferThreadSliceLengths_E0_E1_K_E2,
+                                          ABlockTransferThreadClusterLengths_E0_E1_K_E2,
+                                          Sequence<2, 0, 1, 3>,
+                                          Sequence<2, 0, 1, 3>,
+                                          3,
+                                          ABlockTransferSrcScalarPerVector_E2,
+                                          ABlockTransferDstScalarPerVector_E2,
+                                          false, // don't move back src coordinate after threadwise copy
+                                          Sequence<0, 2, 3, 4, 1, 5>,
+                                          5,
+                                          BThreadTransferSrcScalarPerVector_E2,
+                                          false, // don't move back src coordinate after threadwise copy, which will be fused with
+                                                 // MoveSrcSliceWindow() to save addr computation
+                                          Sequence<2, 3, 1, 0>,
+                                          0,
+                                          CThreadTransferDstScalarPerVector_K,
+                                          decltype(a_e0_e1_k_e2_global_step_hacks),
+                                          decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks),
+                                          decltype(c_k_n_ho_wo_global_tensor_step_hacks),
+                                          decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
+                                          decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack)>;
-        const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
-
-        const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
-
-        const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
-
-        index_t nrepeat = 100;
-
-        for(index_t i = 0; i < 5; ++i)
+        using AGridDesc_E0_E1_K_E2       = decltype(a_e0_e1_k_e2_grid_desc);
+        using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc);
+        using CGridDesc_K_N_Ho_Wo        = decltype(c_k_n_hop_wop_grid_desc);
+
+        const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
+
+        const bool has_main_k_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
+
+        const bool has_double_tail_k_block_loop = (E1 / E1PerBlock) % 2 == 0;
+
+        std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
+                  << " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop
+                  << std::endl;
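The two predicates control how the E1 reduction loop is pipelined: has_main_k_block_loop asks whether there is more than one double-buffered pair of E1PerBlock steps, and has_double_tail_k_block_loop asks whether the block-step count is even. Checked on the concrete values from the device wrapper above (E1 = 36, E1PerBlock = 4), purely as an illustration:

    static_assert((36 + 4) / (2 * 4) > 1, "5 > 1: the main double-buffered loop runs");
    static_assert((36 / 4) % 2 != 0, "9 steps is odd: single-tail, not double-tail, epilogue");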
-        {
-            std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-            KernelTimer timer;
-            timer.Start();
-
-            std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
-                      << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
-                      << std::endl;
-
-            for(index_t j = 0; j < nrepeat; ++j)
-            {
-                if(has_main_k_block_loop && has_double_tail_k_block_loop)
-                {
-                    const auto kernel =
-                        run_gridwise_operation<gridwise_gemm,
-                                               decltype(wei_e_k_global_desc),
-                                               const FloatAB*,
-                                               decltype(in_e_n_ho_wo_global_desc),
-                                               const FloatAB*,
-                                               decltype(out_k_n_ho_wo_global_desc),
-                                               FloatC*,
-                                               integral_constant<bool, true>,
-                                               integral_constant<bool, true>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  wei_e_k_global_desc,
-                                  p_wei_global,
-                                  in_e_n_ho_wo_global_desc,
-                                  p_in_global,
-                                  out_k_n_ho_wo_global_desc,
-                                  p_out_global,
-                                  integral_constant<bool, true>{},
-                                  integral_constant<bool, true>{});
-                }
-                else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-                {
-                    const auto kernel =
-                        run_gridwise_operation<gridwise_gemm,
-                                               decltype(wei_e_k_global_desc),
-                                               const FloatAB*,
-                                               decltype(in_e_n_ho_wo_global_desc),
-                                               const FloatAB*,
-                                               decltype(out_k_n_ho_wo_global_desc),
-                                               FloatC*,
-                                               integral_constant<bool, true>,
-                                               integral_constant<bool, false>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  wei_e_k_global_desc,
-                                  p_wei_global,
-                                  in_e_n_ho_wo_global_desc,
-                                  p_in_global,
-                                  out_k_n_ho_wo_global_desc,
-                                  p_out_global,
-                                  integral_constant<bool, true>{},
-                                  integral_constant<bool, false>{});
-                }
-                else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-                {
-                    const auto kernel =
-                        run_gridwise_operation<gridwise_gemm,
-                                               decltype(wei_e_k_global_desc),
-                                               const FloatAB*,
-                                               decltype(in_e_n_ho_wo_global_desc),
-                                               const FloatAB*,
-                                               decltype(out_k_n_ho_wo_global_desc),
-                                               FloatC*,
-                                               integral_constant<bool, false>,
-                                               integral_constant<bool, true>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  wei_e_k_global_desc,
-                                  p_wei_global,
-                                  in_e_n_ho_wo_global_desc,
-                                  p_in_global,
-                                  out_k_n_ho_wo_global_desc,
-                                  p_out_global,
-                                  integral_constant<bool, false>{},
-                                  integral_constant<bool, true>{});
-                }
-                else
-                {
-                    const auto kernel =
-                        run_gridwise_operation<gridwise_gemm,
-                                               decltype(wei_e_k_global_desc),
-                                               const FloatAB*,
-                                               decltype(in_e_n_ho_wo_global_desc),
-                                               const FloatAB*,
-                                               decltype(out_k_n_ho_wo_global_desc),
-                                               FloatC*,
-                                               integral_constant<bool, false>,
-                                               integral_constant<bool, false>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  wei_e_k_global_desc,
-                                  p_wei_global,
-                                  in_e_n_ho_wo_global_desc,
-                                  p_in_global,
-                                  out_k_n_ho_wo_global_desc,
-                                  p_out_global,
-                                  integral_constant<bool, false>{},
-                                  integral_constant<bool, false>{});
-                }
-            }
-
-            timer.End();
-
-            float ave_time = timer.GetElapsedTime() / nrepeat;
-
-            float perf = static_cast<float>(calculate_convolution_flops(
-                             in_n_c_hi_wi_global_desc,
-                             wei_k_c_y_x_global_desc,
-                             out_n_k0_ho_wo_k1_global_desc)) /
-                         (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-        }
+        const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(I0, I0))),
+                make_tuple(Sequence<0, 1>{}),
+                make_tuple(Sequence<0>{}));
+
+        using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo =
+            decltype(c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
+
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+        float ave_time = 0;
+
+        if(has_main_k_block_loop && has_double_tail_k_block_loop)
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     true,
+                                     true>;
+
+            ave_time = launch_and_time_kernel(kernel,
+                                              nrepeat,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              p_a_grid,
+                                              p_b_grid,
+                                              p_c_grid,
+                                              a_e0_e1_k_e2_grid_desc,
+                                              b_e0_e1_n_ho_wo_e2_grid_desc,
+                                              c_k_n_hop_wop_grid_desc,
+                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
+        }
+        else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     true,
+                                     false>;
+
+            ave_time = launch_and_time_kernel(kernel,
+                                              nrepeat,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              p_a_grid,
+                                              p_b_grid,
+                                              p_c_grid,
+                                              a_e0_e1_k_e2_grid_desc,
+                                              b_e0_e1_n_ho_wo_e2_grid_desc,
+                                              c_k_n_hop_wop_grid_desc,
+                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
+        }
+        else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     false,
+                                     true>;
+
+            ave_time = launch_and_time_kernel(kernel,
+                                              nrepeat,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              p_a_grid,
+                                              p_b_grid,
+                                              p_c_grid,
+                                              a_e0_e1_k_e2_grid_desc,
+                                              b_e0_e1_n_ho_wo_e2_grid_desc,
+                                              c_k_n_hop_wop_grid_desc,
+                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
+        }
+        else
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     false,
+                                     false>;
+
+            ave_time = launch_and_time_kernel(kernel,
+                                              nrepeat,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              p_a_grid,
+                                              p_b_grid,
+                                              p_c_grid,
+                                              a_e0_e1_k_e2_grid_desc,
+                                              b_e0_e1_n_ho_wo_e2_grid_desc,
+                                              c_k_n_hop_wop_grid_desc,
+                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
+        }
+
+        return ave_time;
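The four launch branches differ only in two compile-time booleans, so the flags are resolved once on the host and each device kernel instantiation is branch-free with respect to them. A self-contained sketch of the pattern (launch<> is a hypothetical stand-in for the kernel_gemm_dlops_v2 plus launch_and_time_kernel pair, not a CK API):

    // sketch: two runtime bools select one of four compile-time kernel variants
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    float launch(int nrepeat)
    {
        // a real implementation would instantiate the kernel template on the
        // two bools here and return the measured average time
        return 0.0f * nrepeat;
    }

    inline float dispatch(bool main_loop, bool double_tail, int nrepeat)
    {
        if(main_loop && double_tail)  return launch<true, true>(nrepeat);
        if(main_loop && !double_tail) return launch<true, false>(nrepeat);
        if(!main_loop && double_tail) return launch<false, true>(nrepeat);
        return launch<false, false>(nrepeat);
    }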
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+        DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2));
+        DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
+        DeviceMem c_k_n_hop_wop_grid_desc_dev_buf(sizeof(CGridDesc_K_N_Ho_Wo));
+        DeviceMem c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf(
+            sizeof(CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo));
+
+        a_e0_e1_k_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k_e2_grid_desc);
+        b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.ToDevice(&b_e0_e1_n_ho_wo_e2_grid_desc);
+        c_k_n_hop_wop_grid_desc_dev_buf.ToDevice(&c_k_n_hop_wop_grid_desc);
+        c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
+            &c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
+
+        float ave_time = 0;
+
+        if(has_main_k_block_loop && has_double_tail_k_block_loop)
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     true,
+                                     true>;
+
+            ave_time = launch_and_time_kernel(
+                kernel,
+                nrepeat,
+                dim3(grid_size),
+                dim3(BlockSize),
+                0,
+                p_a_grid,
+                p_b_grid,
+                p_c_grid,
+                cast_pointer_to_constant_address_space(
+                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+        }
+        else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     true,
+                                     false>;
+
+            ave_time = launch_and_time_kernel(
+                kernel,
+                nrepeat,
+                dim3(grid_size),
+                dim3(BlockSize),
+                0,
+                p_a_grid,
+                p_b_grid,
+                p_c_grid,
+                cast_pointer_to_constant_address_space(
+                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+        }
+        else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     false,
+                                     true>;
+
+            ave_time = launch_and_time_kernel(
+                kernel,
+                nrepeat,
+                dim3(grid_size),
+                dim3(BlockSize),
+                0,
+                p_a_grid,
+                p_b_grid,
+                p_c_grid,
+                cast_pointer_to_constant_address_space(
+                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+        }
+        else
+        {
+            const auto kernel =
+                kernel_gemm_dlops_v2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                     false,
+                                     false>;
+
+            ave_time = launch_and_time_kernel(
+                kernel,
+                nrepeat,
+                dim3(grid_size),
+                dim3(BlockSize),
+                0,
+                p_a_grid,
+                p_b_grid,
+                p_c_grid,
+                cast_pointer_to_constant_address_space(
+                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
+                cast_pointer_to_constant_address_space(
+                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+        }
+
+        return ave_time;
+#endif
+#endif
     }
 };
...
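Under CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER, each descriptor is serialized into device memory once and the kernel receives only a pointer, instead of taking the descriptor by value in its argument list. A minimal sketch of that upload step using raw HIP calls (DeviceMem is CK's own wrapper; this stand-alone helper is illustrative only):

    #include <hip/hip_runtime.h>

    // copy a host-side descriptor object into a device buffer; the kernel
    // later casts the returned pointer back to const Desc*
    template <typename Desc>
    void* copy_desc_to_device(const Desc& desc)
    {
        void* p = nullptr;
        hipMalloc(&p, sizeof(Desc));                              // device buffer
        hipMemcpy(p, &desc, sizeof(Desc), hipMemcpyHostToDevice); // one-time upload
        return p;
    }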
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
deleted 100644 → 0  (view file @ d3146496)

-#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
-#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
-
-#include "common_header.hpp"
-#include "tensor_descriptor.hpp"
-#include "tensor_descriptor_helper.hpp"
-#include "gridwise_gemm_dlops_v2.hpp"
-
-template <ck::index_t BlockSize,
-          typename FloatAB,
-          typename FloatAcc,
-          typename FloatC,
-          ck::index_t E1,
-          ck::index_t E2,
-          ck::index_t KPerBlock,
-          ck::index_t HoPerBlock,
-          ck::index_t WoPerBlock,
-          ck::index_t E1PerBlock,
-          ck::index_t KPerThread,
-          ck::index_t HoPerThread,
-          ck::index_t WoPerThread,
-          ck::index_t EPerThread,
-          typename ABlockTransferThreadSliceLengths_E0_E1_K_E2,
-          typename ABlockTransferThreadClusterLengths_E0_E1_K_E2,
-          ck::index_t ABlockTransferSrcScalarPerVector_E2,
-          ck::index_t ABlockTransferDstScalarPerVector_E2,
-          ck::index_t BThreadTransferSrcScalarPerVector_E2,
-          ck::index_t CThreadTransferDstScalarPerVector_K>
-struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
-{
-    ...
-};
-
-#endif

(The remaining deleted body, the Run() member with the same descriptor construction, step hacks, GridwiseGemm instantiation, and kernel dispatch, is line for line the code shown as additions to driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp above; the whole 542-line file was removed after being merged into that header, which is also why the device wrapper's include switched away from the _outpad header.)