Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
ccc4a1d3
Unverified
Commit
ccc4a1d3
authored
Aug 16, 2021
by
Chao Liu
Committed by
GitHub
Aug 16, 2021
Browse files
Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration
parents
3b866461
16effa76
Changes
144
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1145 additions
and
6129 deletions
+1145
-6129
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
...tion_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
+357
-0
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
...ution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
+64
-51
external/half/include/half.hpp
external/half/include/half.hpp
+0
-5670
host/CMakeLists.txt
host/CMakeLists.txt
+0
-2
host/driver_offline/CMakeLists.txt
host/driver_offline/CMakeLists.txt
+2
-3
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
...ackward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
+22
-33
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
...kward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
+22
-33
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
+23
-32
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+7
-10
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
...ion_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
+23
-34
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...on_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+23
-32
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
+21
-32
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
+21
-24
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+22
-33
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+7
-9
host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp
+23
-26
host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
.../driver_offline/include/driver_contraction_dlops_v1r2.hpp
+33
-37
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+31
-34
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
...orward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
+31
-34
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
+413
-0
No files found.
composable_kernel/src/kernel_wrapper/
dynamic_
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
→
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
View file @
ccc4a1d3
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_
dynamic_
gemm_xdlops_v2r3.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
using
namespace
ck
;
using
namespace
ck
;
...
@@ -60,8 +60,7 @@ using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDst
...
@@ -60,8 +60,7 @@ using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDst
constexpr
index_t
CThreadTransferSrcDstVectorDim
=
CK_PARAM_CThreadTransferSrcDstVectorDim
;
constexpr
index_t
CThreadTransferSrcDstVectorDim
=
CK_PARAM_CThreadTransferSrcDstVectorDim
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
CK_PARAM_CThreadTransferDstScalarPerVector
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
CK_PARAM_CThreadTransferDstScalarPerVector
;
extern
"C"
__global__
void
extern
"C"
__global__
void
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare
(
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare
(
int
n
,
int
n
,
int
hi
,
int
hi
,
int
wi
,
int
wi
,
...
@@ -89,12 +88,9 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
...
@@ -89,12 +88,9 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
const
index_t
ho
=
(
hi
+
leftPadH
+
rightPadH
-
convDilationY
*
(
y
-
1
)
-
1
)
/
convStrideH
+
1
;
const
index_t
ho
=
(
hi
+
leftPadH
+
rightPadH
-
convDilationY
*
(
y
-
1
)
-
1
)
/
convStrideH
+
1
;
const
index_t
wo
=
(
wi
+
leftPadW
+
rightPadW
-
convDilationX
*
(
x
-
1
)
-
1
)
/
convStrideW
+
1
;
const
index_t
wo
=
(
wi
+
leftPadW
+
rightPadW
-
convDilationX
*
(
x
-
1
)
-
1
)
/
convStrideW
+
1
;
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
n
,
hi
,
wi
,
c
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
n
,
hi
,
wi
,
c
));
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k
,
y
,
x
,
c
));
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
n
,
ho
,
wo
,
k
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
k
,
y
,
x
,
c
));
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
n
,
ho
,
wo
,
k
));
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
in_n_hi_wi_c_desc
,
in_n_hi_wi_c_desc
,
...
@@ -114,12 +110,12 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
...
@@ -114,12 +110,12 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -127,68 +123,68 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
...
@@ -127,68 +123,68 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
Gridwise
Dynamic
Gemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
AK0MK1GridDesc
,
AK0MK1GridDesc
,
BK0NK1GridDesc
,
BK0NK1GridDesc
,
CMNGridDesc
,
CMNGridDesc
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
MPerWave
,
MPerWave
,
NPerWave
,
NPerWave
,
K1
,
K1
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
false
>
;
false
>
;
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
auto
c_m0_m1_m2_n_grid_desc
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
...
@@ -212,7 +208,7 @@ extern "C" __global__ void
...
@@ -212,7 +208,7 @@ extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
dynamic_
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk
(
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
...
@@ -225,14 +221,13 @@ extern "C" __global__ void
...
@@ -225,14 +221,13 @@ extern "C" __global__ void
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_hi_wi_c_desc
=
constexpr
auto
in_n_hi_wi_c_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
256
,
28
,
28
,
256
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
28
,
28
,
256
));
constexpr
auto
wei_k_y_x_c_desc
=
constexpr
auto
wei_k_y_x_c_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
256
,
3
,
3
,
256
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
3
,
3
,
256
));
constexpr
auto
out_n_ho_wo_k_desc
=
constexpr
auto
out_n_ho_wo_k_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
256
,
28
,
28
,
256
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
28
,
28
,
256
));
constexpr
auto
descs
=
constexpr
auto
descs
=
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
in_n_hi_wi_c_desc
,
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
in_n_hi_wi_c_desc
,
...
@@ -252,12 +247,12 @@ extern "C" __global__ void
...
@@ -252,12 +247,12 @@ extern "C" __global__ void
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc_tmp
);
using
BK0NK1GridDesc
=
decltype
(
b_k0_n_k1_grid_desc_tmp
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
CMNGridDesc
=
decltype
(
c_m_n_grid_desc
);
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{})));
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -265,68 +260,68 @@ extern "C" __global__ void
...
@@ -265,68 +260,68 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{})));
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{})));
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
Gridwise
Dynamic
Gemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
AK0MK1GridDesc
,
AK0MK1GridDesc
,
BK0NK1GridDesc
,
BK0NK1GridDesc
,
CMNGridDesc
,
CMNGridDesc
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
MPerWave
,
MPerWave
,
NPerWave
,
NPerWave
,
K1
,
K1
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
false
>
;
false
>
;
constexpr
auto
c_m0_m1_m2_n_grid_desc_tmp
=
constexpr
auto
c_m0_m1_m2_n_grid_desc_tmp
=
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
GridwiseGemm
::
MakeCM0M1M2NGridDescriptor
(
c_m_n_grid_desc
);
constexpr
auto
c_blockid_to_m0_n0_block_cluster_adaptor_tmp
=
constexpr
auto
c_blockid_to_m0_n0_block_cluster_adaptor_tmp
=
...
...
composable_kernel/src/kernel_wrapper/
dynamic_
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
→
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
View file @
ccc4a1d3
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_
dynamic_
contraction_dlops_v1r2.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
using
namespace
ck
;
using
namespace
ck
;
...
@@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBloc
...
@@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBloc
constexpr
bool
HasDoubleTailKBlockLoop
=
static_cast
<
bool
>
(
CK_PARAM_HasDoubleTailKBlockLoop
);
constexpr
bool
HasDoubleTailKBlockLoop
=
static_cast
<
bool
>
(
CK_PARAM_HasDoubleTailKBlockLoop
);
extern
"C"
__global__
void
extern
"C"
__global__
void
dynamic_
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare
(
in
dex_
t
N
,
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare
(
int
N
_
,
index_
t
C
,
in
t
C
_
,
index_
t
Hi
,
in
t
Hi
_
,
index_
t
Wi
,
in
t
Wi
_
,
index_
t
K
,
in
t
K
_
,
index_
t
Y
,
in
t
Y
_
,
index_
t
X
,
in
t
X
_
,
index_
t
ConvStrideH
,
in
t
ConvStrideH
_
,
index_
t
ConvStrideW
,
in
t
ConvStrideW
_
,
index_
t
ConvDilationH
,
in
t
ConvDilationH
_
,
index_
t
ConvDilationW
,
in
t
ConvDilationW
_
,
index_
t
InLeftPadH
,
in
t
InLeftPadH
_
,
index_
t
InLeftPadW
,
in
t
InLeftPadW
_
,
index_
t
InRightPadH
,
in
t
InRightPadH
_
,
index_
t
InRightPadW
,
in
t
InRightPadW
_
,
void
*
p_desc_tuple
)
void
*
p_desc_tuple
)
{
{
index_t
N
=
static_cast
<
index_t
>
(
N_
);
index_t
C
=
static_cast
<
index_t
>
(
C_
);
index_t
Hi
=
static_cast
<
index_t
>
(
Hi_
);
index_t
Wi
=
static_cast
<
index_t
>
(
Wi_
);
index_t
K
=
static_cast
<
index_t
>
(
K_
);
index_t
Y
=
static_cast
<
index_t
>
(
Y_
);
index_t
X
=
static_cast
<
index_t
>
(
X_
);
index_t
ConvStrideH
=
static_cast
<
index_t
>
(
ConvStrideH_
);
index_t
ConvStrideW
=
static_cast
<
index_t
>
(
ConvStrideW_
);
index_t
ConvDilationH
=
static_cast
<
index_t
>
(
ConvDilationH_
);
index_t
ConvDilationW
=
static_cast
<
index_t
>
(
ConvDilationW_
);
index_t
InLeftPadH
=
static_cast
<
index_t
>
(
InLeftPadH_
);
index_t
InLeftPadW
=
static_cast
<
index_t
>
(
InLeftPadW_
);
index_t
InRightPadH
=
static_cast
<
index_t
>
(
InRightPadH_
);
index_t
InRightPadW
=
static_cast
<
index_t
>
(
InRightPadW_
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -88,12 +104,9 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
...
@@ -88,12 +104,9 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
const
index_t
Wo
=
const
index_t
Wo
=
(
Wi
+
InLeftPadW
+
InRightPadW
-
ConvDilationW
*
(
X
-
1
)
-
1
)
/
ConvStrideW
+
1
;
(
Wi
+
InLeftPadW
+
InRightPadW
-
ConvDilationW
*
(
X
-
1
)
-
1
)
/
ConvStrideW
+
1
;
const
auto
in_n_c_hi_wi_desc
=
const
auto
in_n_c_hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
C
,
Hi
,
Wi
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
C
,
Hi
,
Wi
));
const
auto
wei_k_c_y_x_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
,
Y
,
X
));
const
auto
wei_k_c_y_x_desc
=
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
,
Wo
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C
,
Y
,
X
));
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
,
Wo
));
const
auto
descs
=
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad
(
const
auto
descs
=
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad
(
wei_k_c_y_x_desc
,
wei_k_c_y_x_desc
,
...
@@ -114,7 +127,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
...
@@ -114,7 +127,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
...
@@ -126,7 +139,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
...
@@ -126,7 +139,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
...
@@ -138,7 +151,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
...
@@ -138,7 +151,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
...
@@ -154,13 +167,13 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
...
@@ -154,13 +167,13 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseContraction
=
using
GridwiseContraction
=
Gridwise
Dynamic
ContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
<
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
<
BlockSize
,
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
...
@@ -194,11 +207,11 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
...
@@ -194,11 +207,11 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
if
(
get_block_1d_id
()
==
0
&&
get_thread_local_1d_id
()
==
0
)
if
(
get_block_1d_id
()
==
0
&&
get_thread_local_1d_id
()
==
0
)
{
{
...
@@ -220,7 +233,7 @@ extern "C" __global__ void
...
@@ -220,7 +233,7 @@ extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
dynamic_
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw
(
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
...
@@ -232,11 +245,11 @@ extern "C" __global__ void
...
@@ -232,11 +245,11 @@ extern "C" __global__ void
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_c_hi_wi_desc
=
constexpr
auto
in_n_c_hi_wi_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
256
,
256
,
28
,
28
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
256
,
28
,
28
));
constexpr
auto
wei_k_c_y_x_desc
=
constexpr
auto
wei_k_c_y_x_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
256
,
256
,
3
,
3
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
256
,
3
,
3
));
constexpr
auto
out_n_k_ho_wo_desc
=
constexpr
auto
out_n_k_ho_wo_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
256
,
256
,
28
,
28
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
256
,
256
,
28
,
28
));
constexpr
auto
descs
=
constexpr
auto
descs
=
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad
(
wei_k_c_y_x_desc
,
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad
(
wei_k_c_y_x_desc
,
...
@@ -257,7 +270,7 @@ extern "C" __global__ void
...
@@ -257,7 +270,7 @@ extern "C" __global__ void
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
BGridDesc_GK0_GN0_GN1_GK1
=
decltype
(
b_grid_desc_gk0_gn0_gn1_gk1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
CGridDesc_GM0_GM1_GN0_GN1
=
decltype
(
c_grid_desc_gm0_gm1_gn0_gn1
);
using
AGrid
I
te
rator
Hacks
=
using
AGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
...
@@ -269,7 +282,7 @@ extern "C" __global__ void
...
@@ -269,7 +282,7 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
BGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
BGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
...
@@ -281,7 +294,7 @@ extern "C" __global__ void
...
@@ -281,7 +294,7 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})));
// 4-: GK1
using
CGrid
I
te
rator
Hacks
=
decltype
(
make_tuple
(
using
CGrid
S
te
p
Hacks
=
decltype
(
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
...
@@ -297,13 +310,13 @@ extern "C" __global__ void
...
@@ -297,13 +310,13 @@ extern "C" __global__ void
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{})));
// 5-: GN1
using
AGridMoveSliceWindow
I
te
rator
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
AGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
;
using
BGridMoveSliceWindow
I
te
rator
Hacks
=
using
BGridMoveSliceWindow
S
te
p
Hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
;
using
GridwiseContraction
=
using
GridwiseContraction
=
Gridwise
Dynamic
ContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
<
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
<
BlockSize
,
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
...
@@ -337,11 +350,11 @@ extern "C" __global__ void
...
@@ -337,11 +350,11 @@ extern "C" __global__ void
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
using
AGridDesc_GK0_GM0_GM10_GM11_GK1
=
using
AGridDesc_GK0_GM0_GM10_GM11_GK1
=
decltype
(
GridwiseContraction
::
MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1
(
decltype
(
GridwiseContraction
::
MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1
(
...
...
external/half/include/half.hpp
deleted
100644 → 0
View file @
3b866461
This source diff could not be displayed because it is too large. You can
view the blob
instead.
host/CMakeLists.txt
View file @
ccc4a1d3
add_subdirectory
(
host_tensor
)
add_subdirectory
(
host_tensor
)
add_subdirectory
(
online_compile
)
add_subdirectory
(
driver_offline
)
add_subdirectory
(
driver_offline
)
add_subdirectory
(
driver_online
)
host/driver_offline/CMakeLists.txt
View file @
ccc4a1d3
...
@@ -9,11 +9,10 @@ include_directories(BEFORE
...
@@ -9,11 +9,10 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/problem_transform
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/problem_transform
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/driver
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/driver
${
PROJECT_SOURCE_DIR
}
/external/rocm/include
${
PROJECT_SOURCE_DIR
}
/external/rocm/include
${
PROJECT_SOURCE_DIR
}
/external/half/include
)
)
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp
)
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE
src/
conv_fwd_driver_offline.cpp
)
set
(
CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp
)
set
(
CONV_BWD_DRIVER_OFFLINE_SOURCE
src/
conv_bwd_driver_offline.cpp
)
add_executable
(
conv_fwd_driver_offline
${
CONV_FWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_fwd_driver_offline
${
CONV_FWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_bwd_driver_offline
${
CONV_BWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_bwd_driver_offline
${
CONV_BWD_DRIVER_OFFLINE_SOURCE
}
)
...
...
host/driver_offline/include/device_
dynamic_
convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
→
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_
dynamic_
gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk
(
void
device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_ho_wo_k_lengths
);
#if 1
#if 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
...
@@ -215,7 +207,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -215,7 +207,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -223,7 +215,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -223,7 +215,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -231,7 +223,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -231,7 +223,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
constexpr
auto
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: NRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: NRepeat
...
@@ -251,15 +243,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -251,15 +243,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
// 7-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
// 7-: N1
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_xdlops_v2r3
<
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -295,11 +287,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -295,11 +287,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
Sequence
<
1
,
3
,
7
,
0
,
2
,
4
,
5
,
6
>
,
Sequence
<
1
,
3
,
7
,
0
,
2
,
4
,
5
,
6
>
,
6
,
6
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
...
@@ -307,11 +299,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -307,11 +299,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
wei_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
,
in_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
in_m0_m1_m2_n_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
@@ -319,16 +311,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
...
@@ -319,16 +311,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Hi
=
in_n_hi_wi_c_lengths
[
I1
];
const
auto
Wi
=
in_n_hi_wi_c_lengths
[
I2
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
float
perf
=
(
float
)
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
float
perf
=
static_cast
<
float
>
(
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
...
...
host/driver_offline/include/device_
dynamic_
convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
→
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_
dynamic_
gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
(
void
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_ho_wo_k_lengths
);
#if 0
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...
@@ -187,7 +179,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -187,7 +179,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
in_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -195,7 +187,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -195,7 +187,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-: gemmm
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: gemmk1
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: gemmk0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: gemmk1
...
@@ -203,7 +195,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -203,7 +195,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemmn
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemmk1
constexpr
auto
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: NRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: NRepeat
...
@@ -223,15 +215,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -223,15 +215,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N1
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_xdlops_v2r3
<
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -271,11 +263,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -271,11 +263,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
#endif
#endif
7
,
7
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
in_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
true
// CAccessOrderMRepeatNRepeat
true
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
...
@@ -283,11 +275,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -283,11 +275,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
out_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
,
in_gemmm_gemmn_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
out_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
in_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
in_m0_m1_m2_n_grid_
s
te
p
_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
@@ -295,16 +287,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
...
@@ -295,16 +287,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Hi
=
in_n_hi_wi_c_lengths
[
I1
];
const
auto
Wi
=
in_n_hi_wi_c_lengths
[
I2
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
float
perf
=
(
float
)
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
float
perf
=
static_cast
<
float
>
(
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_
dynamic_
gemm_dlops_v1r2.hpp"
#include "driver_gemm_dlops_v1r2.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw
(
void
device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
...
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
...
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
const
auto
in_n_c_hi_wi_desc
=
const
auto
in_n_c_hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
in_n_c_hi_wi_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_c_hi_wi_lengths
);
const
auto
wei_k_c_y_x_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_c_y_x_lengths
);
const
auto
wei_k_c_y_x_desc
=
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_c_y_x_lengths
);
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_k_ho_wo_lengths
);
#if 1
#if 1
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
...
@@ -98,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -98,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
in_right_pads
);
in_right_pads
);
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk_gemmm0_gemmn1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk_gemmm0_gemmn1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -108,7 +99,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -108,7 +99,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
...
@@ -116,7 +107,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -116,7 +107,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
=
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -130,10 +121,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -130,10 +121,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
const
auto
wei_gemmk_gemmm_grid_desc
=
descs
[
I0
];
const
auto
wei_gemmk_gemmm_grid_desc
=
descs
[
I0
];
...
@@ -142,7 +133,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -142,7 +133,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_dlops_v1r2
<
float
ave_time
=
driver_gemm_dlops_v1r2
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -180,26 +171,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
...
@@ -180,26 +171,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11
,
GemmCThreadTransferDstScalarPerVector_N11
,
decltype
(
wei_gemmk_gemmm0_gemmn1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmn1_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_
s
te
p
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
i
te
rator
_hacks
)
>
(
decltype
(
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
s
te
p
_hacks
)
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
wei_gemmk_gemmm_grid_desc
,
wei_gemmk_gemmm_grid_desc
,
in_gemmk_gemmn_grid_desc
,
in_gemmk_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk_gemmm0_gemmn1_grid_
i
te
rator
_hacks
,
wei_gemmk_gemmm0_gemmn1_grid_
s
te
p
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_
i
te
rator
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_
s
te
p
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_
s
te
p
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
ccc4a1d3
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_
dynamic_
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -13,7 +13,7 @@ template <typename TInWei,
...
@@ -13,7 +13,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(
void
device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
...
@@ -48,12 +48,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -48,12 +48,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
const
auto
in_n_c_hi_wi_desc
=
const
auto
in_n_c_hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
in_n_c_hi_wi_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_c_hi_wi_lengths
);
const
auto
wei_k_c_y_x_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_c_y_x_lengths
);
const
auto
wei_k_c_y_x_desc
=
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_c_y_x_lengths
);
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_k_ho_wo_lengths
);
#if 0
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
...
@@ -212,9 +209,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -212,9 +209,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
#if 0
#if 0
float ave_time = launch_kernel_
dynamic_
gemm_xdlops_v1
float ave_time = launch_kernel_gemm_xdlops_v1
#else
#else
float
ave_time
=
launch_kernel_
dynamic_
gemm_xdlops_v2
float
ave_time
=
launch_kernel_gemm_xdlops_v2
#endif
#endif
<
BlockSize
,
<
BlockSize
,
TInWei
,
TInWei
,
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_
dynamic_
gemm_dlops_v1r3.hpp"
#include "driver_gemm_dlops_v1r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk
(
void
device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
...
@@ -49,14 +44,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -49,14 +44,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_ho_wo_k_lengths
);
#if
1
#if
0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 256;
...
@@ -163,7 +155,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -163,7 +155,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM1
...
@@ -173,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -173,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GemmM1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GemmM1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmN1
...
@@ -183,7 +175,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -183,7 +175,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: GemmN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 3-: GemmK1
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
=
constexpr
auto
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmM0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmM0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM11
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GemmM11
...
@@ -197,15 +189,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -197,15 +189,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 4-: GemmN10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 4-: GemmN10
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 5-: GemmN11
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 5-: GemmN11
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_dlops_v1r3
<
float
ave_time
=
driver_gemm_dlops_v1r3
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -239,22 +231,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -239,22 +231,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11
,
GemmCThreadTransferDstScalarPerVector_N11
,
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
),
decltype
(
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
)
>
(
decltype
(
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
)
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_
s
te
p
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
i
te
rator
_hacks
,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_
s
te
p
_hacks
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
@@ -262,16 +254,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
...
@@ -262,16 +254,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Hi
=
in_n_hi_wi_c_lengths
[
I1
];
const
auto
Wi
=
in_n_hi_wi_c_lengths
[
I2
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
float
perf
=
(
float
)
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
float
perf
=
static_cast
<
float
>
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_
dynamic_
gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
(
void
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
...
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
...
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
const
auto
in_n_c_hi_wi_desc
=
const
auto
in_n_c_hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
in_n_c_hi_wi_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_c_hi_wi_lengths
);
const
auto
wei_k_c_y_x_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_c_y_x_lengths
);
const
auto
wei_k_c_y_x_desc
=
const
auto
out_n_k_ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_c_y_x_lengths
);
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_k_ho_wo_lengths
);
#if 1
#if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
...
@@ -101,12 +92,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -101,12 +92,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -114,7 +105,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -114,7 +105,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
constexpr
auto
out_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -132,15 +123,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -132,15 +123,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_xdlops_v2r3
<
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -176,26 +167,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
...
@@ -176,26 +167,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence
<
3
,
0
,
1
,
2
,
7
,
5
,
4
,
6
>
,
Sequence
<
3
,
0
,
1
,
2
,
7
,
5
,
4
,
6
>
,
7
,
7
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
false
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
false
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
wei_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
out_m0_m1_m2_n_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_
dynamic_
gemm_xdlops_v2r2.hpp"
#include "driver_gemm_xdlops_v2r2.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk
(
void
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_ho_wo_k_lengths
);
#if 1
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...
@@ -129,12 +121,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -129,12 +121,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -142,7 +134,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -142,7 +134,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
constexpr
auto
out_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -152,15 +144,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -152,15 +144,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_xdlops_v2r2
<
float
ave_time
=
driver_gemm_xdlops_v2r2
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -195,22 +187,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -195,22 +187,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
Sequence
<
2
,
3
,
0
,
1
>
,
Sequence
<
2
,
3
,
0
,
1
>
,
2
,
2
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
)
>
(
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
)
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
wei_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
out_m0_m1_m2_n_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
@@ -218,9 +210,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
...
@@ -218,9 +210,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Hi
=
in_n_hi_wi_c_lengths
[
I1
];
const
auto
Wi
=
in_n_hi_wi_c_lengths
[
I2
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_
dynamic_
gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk
(
void
device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
...
@@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
...
@@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_ho_wo_k_lengths
);
#if 1
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...
@@ -185,12 +182,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
...
@@ -185,12 +182,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
...
@@ -198,7 +195,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
...
@@ -198,7 +195,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
constexpr
auto
out_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -216,15 +213,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
...
@@ -216,15 +213,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_xdlops_v2r3
<
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -259,11 +256,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
...
@@ -259,11 +256,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
6
,
6
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
...
@@ -271,11 +268,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
...
@@ -271,11 +268,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
wei_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
wei_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
out_m0_m1_m2_n_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @
ccc4a1d3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_
dynamic_
gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
...
@@ -14,7 +14,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
(
void
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
(
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
InLengths
&
in_n_hi_wi_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
WeiLengths
&
wei_k_y_x_c_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
const
OutLengths
&
out_n_ho_wo_k_lengths
,
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
ToDevice
(
out_n_ho_wo_k
.
mData
.
data
());
const
auto
in_n_hi_wi_c_desc
=
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_c_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_hi_wi_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_c_lengths
);
const
auto
wei_k_y_x_c_desc
=
const
auto
out_n_ho_wo_k_desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_k_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_y_x_c_lengths
);
const
auto
out_n_ho_wo_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_ho_wo_k_lengths
);
#if 0
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
...
@@ -241,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -241,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 1+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
...
@@ -249,7 +241,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -249,7 +241,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 1-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 1-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
...
@@ -257,7 +249,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -257,7 +249,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
=
constexpr
auto
out_m0_m1_m2_n_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: MRepeat
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: MRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: NRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: NRepeat
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: MWaves
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 2+: MWaves
...
@@ -275,15 +267,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -275,15 +267,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N1
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
gemm_xdlops_v2r3
<
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -319,11 +311,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -319,11 +311,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
7
,
7
,
GemmCThreadTransferDstScalarPerVector
,
GemmCThreadTransferDstScalarPerVector
,
decltype
(
in_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
),
decltype
(
out_m0_m1_m2_n_grid_
s
te
p
_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
...
@@ -331,11 +323,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -331,11 +323,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
in_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_
i
te
rator
_hacks
,
in_gemmk0_gemmm_gemmk1_grid_
s
te
p
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_
i
te
rator
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_
s
te
p
_hacks
,
out_m0_m1_m2_n_grid_
i
te
rator
_hacks
,
out_m0_m1_m2_n_grid_
s
te
p
_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
i
te
rator
_hacks
,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
{
{
...
@@ -343,16 +335,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -343,16 +335,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Hi
=
in_n_hi_wi_c_lengths
[
I1
];
const
auto
Wi
=
in_n_hi_wi_c_lengths
[
I2
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
Y
=
wei_k_y_x_c_lengths
[
I1
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
const
auto
X
=
wei_k_y_x_c_lengths
[
I2
];
float
perf
=
(
float
)
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
float
perf
=
static_cast
<
float
>
(
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
View file @
ccc4a1d3
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "driver_
dynamic_
convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_
dynamic_
convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
ck
::
index_t
InWeiVectorSize
,
ck
::
index_t
InWeiVectorSize
,
...
@@ -15,7 +15,7 @@ template <typename TInWei,
...
@@ -15,7 +15,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw
(
void
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
...
@@ -26,7 +26,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -26,7 +26,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const
Tensor
<
TInWei
>&
in_n_c_hi_wi
,
const
Tensor
<
TInWei
>&
in_n_c_hi_wi
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
const
Tensor
<
TInWei
>&
wei_k_c_y_x
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
ck
::
index_t
nrepeat
)
ck
::
index_t
/*
nrepeat
*/
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
in_n_c0_hi_wi_c1_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
in_n_c0_hi_wi_c1_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
wei_k_c0_y_x_c1_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
wei_k_c0_y_x_c1_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
const
auto
in_n_c0_hi_wi_desc
=
const
auto
in_n_c0_hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
C0
,
Hi
,
Wi
));
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
C0
,
Hi
,
Wi
));
const
auto
wei_k_c0_y_x_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C0
,
Y
,
X
));
const
auto
wei_k_c0_y_x_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C0
,
Y
,
X
));
const
auto
out_n_k0_ho_wo_k1_desc
=
const
auto
out_n_k0_ho_wo_k1_desc
=
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
#if 1
#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
// cdata = 64, BlockSize = 64, 16x8x32x4
...
...
host/driver_offline/include/device_
dynamic_
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp
View file @
ccc4a1d3
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_
dynamic_
contraction_dlops_v1r2.hpp"
#include "driver_contraction_dlops_v1r2.hpp"
template
<
typename
TInWei
,
template
<
typename
TInWei
,
typename
TAcc
,
typename
TAcc
,
...
@@ -15,7 +15,7 @@ template <typename TInWei,
...
@@ -15,7 +15,7 @@ template <typename TInWei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
void
device_
dynamic_
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw
(
void
device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw
(
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
InLengths
&
in_n_c_hi_wi_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
WeiLengths
&
wei_k_c_y_x_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
const
OutLengths
&
out_n_k_ho_wo_lengths
,
...
@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
...
@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_k_c_y_x_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
const
auto
in_desc_n_c_hi_wi
=
const
auto
in_desc_n_c_hi_wi
=
make_naive_tensor_descriptor_packed
(
in_n_c_hi_wi_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
in_n_c_hi_wi_lengths
);
const
auto
wei_desc_k_c_y_x
=
make_naive_tensor_descriptor_packed
(
wei_k_c_y_x_lengths
);
const
auto
wei_desc_k_c_y_x
=
const
auto
out_desc_n_k_ho_wo
=
make_naive_tensor_descriptor_packed
(
out_n_k_ho_wo_lengths
);
make_dynamic_naive_tensor_descriptor_packed_v2
(
wei_k_c_y_x_lengths
);
const
auto
out_desc_n_k_ho_wo
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
out_n_k_ho_wo_lengths
);
#if 1
#if 1
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
...
@@ -133,7 +130,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
...
@@ -133,7 +130,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const
auto
out_grid_desc_gm0_gm1_gn0_gn1
=
descs
[
I2
];
const
auto
out_grid_desc_gm0_gm1_gn0_gn1
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_grid_
i
te
rator
_hacks
=
constexpr
auto
wei_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GM10
...
@@ -145,7 +142,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
...
@@ -145,7 +142,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GM11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 4-: GK1
constexpr
auto
in_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
in_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: GN10
...
@@ -157,7 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
...
@@ -157,7 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: GN11
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 4-: GK1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 4-: GK1
constexpr
auto
out_grid_
i
te
rator
_hacks
=
make_tuple
(
constexpr
auto
out_grid_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GM10
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: BM0
...
@@ -173,14 +170,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
...
@@ -173,14 +170,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{},
// 4-: BN0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{}));
// 5-: GN1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
>
{}));
// 5-: GN1
constexpr
auto
wei_grid_move_slice_window_
i
te
rator
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
wei_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
in_grid_move_slice_window_
i
te
rator
_hacks
=
constexpr
auto
in_grid_move_slice_window_
s
te
p
_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
float
ave_time
=
driver_
dynamic_
contraction_dlops_v1r2
<
float
ave_time
=
driver_contraction_dlops_v1r2
<
BlockSize
,
BlockSize
,
TInWei
,
TInWei
,
TAcc
,
TAcc
,
...
@@ -214,26 +211,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
...
@@ -214,26 +211,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
Sequence
<
3
,
4
,
5
,
0
,
1
,
2
>
,
// CThreadTransferSrcDstAccessOrder
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector_BN1
,
CThreadTransferDstScalarPerVector_BN1
,
decltype
(
wei_grid_
i
te
rator
_hacks
),
decltype
(
wei_grid_
s
te
p
_hacks
),
decltype
(
in_grid_
i
te
rator
_hacks
),
decltype
(
in_grid_
s
te
p
_hacks
),
decltype
(
out_grid_
i
te
rator
_hacks
),
decltype
(
out_grid_
s
te
p
_hacks
),
decltype
(
wei_grid_move_slice_window_
i
te
rator
_hacks
),
decltype
(
wei_grid_move_slice_window_
s
te
p
_hacks
),
decltype
(
in_grid_move_slice_window_
i
te
rator
_hacks
)
>
(
decltype
(
in_grid_move_slice_window_
s
te
p
_hacks
)
>
(
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
wei_grid_desc_gk0_gm0_gm1_gk1
,
wei_grid_desc_gk0_gm0_gm1_gk1
,
in_grid_desc_gk0_gn0_gn1_gk1
,
in_grid_desc_gk0_gn0_gn1_gk1
,
out_grid_desc_gm0_gm1_gn0_gn1
,
out_grid_desc_gm0_gm1_gn0_gn1
,
wei_grid_
i
te
rator
_hacks
,
wei_grid_
s
te
p
_hacks
,
in_grid_
i
te
rator
_hacks
,
in_grid_
s
te
p
_hacks
,
out_grid_
i
te
rator
_hacks
,
out_grid_
s
te
p
_hacks
,
wei_grid_move_slice_window_
i
te
rator
_hacks
,
wei_grid_move_slice_window_
s
te
p
_hacks
,
in_grid_move_slice_window_
i
te
rator
_hacks
,
in_grid_move_slice_window_
s
te
p
_hacks
,
nrepeat
);
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
in_desc_n_c_hi_wi
,
wei_desc_k_c_y_x
,
out_desc_n_k_ho_wo
)
/
in_desc_n_c_hi_wi
,
wei_desc_k_c_y_x
,
out_desc_n_k_ho_wo
)
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
...
...
host/driver_offline/include/driver_
dynamic_
contraction_dlops_v1r2.hpp
→
host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
View file @
ccc4a1d3
#ifndef DRIVER_
DYNAMIC_
CONTRACTION_DLOPS_V1R2_HPP
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_
DYNAMIC_
CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_
dynamic_
contraction_dlops_v1r2.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
template
<
ck
::
index_t
BlockSize
,
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAB
,
...
@@ -39,24 +39,24 @@ template <ck::index_t BlockSize,
...
@@ -39,24 +39,24 @@ template <ck::index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
typename
AGrid
I
te
rator
Hacks
,
typename
AGrid
S
te
p
Hacks
,
typename
BGrid
I
te
rator
Hacks
,
typename
BGrid
S
te
p
Hacks
,
typename
CGrid
I
te
rator
Hacks
,
typename
CGrid
S
te
p
Hacks
,
typename
AGridMoveSliceWindow
I
te
rator
Hacks
,
typename
AGridMoveSliceWindow
S
te
p
Hacks
,
typename
BGridMoveSliceWindow
I
te
rator
Hacks
>
typename
BGridMoveSliceWindow
S
te
p
Hacks
>
__host__
float
__host__
float
driver_
dynamic_
contraction_dlops_v1r2
(
const
FloatAB
*
p_a_grid
,
driver_contraction_dlops_v1r2
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
FloatC
*
p_c_grid
,
const
AGridDesc_GK0_GM0_GM1_GK1
&
a_grid_desc_gk0_gm0_gm1_gk1
,
const
AGridDesc_GK0_GM0_GM1_GK1
&
a_grid_desc_gk0_gm0_gm1_gk1
,
const
BGridDesc_GK0_GN0_GN1_GK1
&
b_grid_desc_gk0_gn0_gn1_gk1
,
const
BGridDesc_GK0_GN0_GN1_GK1
&
b_grid_desc_gk0_gn0_gn1_gk1
,
const
CGridDesc_GM0_GM1_GN0_GN1
&
c_grid_desc_gm0_gm1_gn0_gn1
,
const
CGridDesc_GM0_GM1_GN0_GN1
&
c_grid_desc_gm0_gm1_gn0_gn1
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
// GEMM
// GEMM
using
GridwiseContraction
=
using
GridwiseContraction
=
Gridwise
Dynamic
ContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
<
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
<
BlockSize
,
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
...
@@ -104,11 +104,11 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -104,11 +104,11 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
>
;
BGridMoveSliceWindow
S
te
p
Hacks
>
;
const
auto
GK0
=
a_grid_desc_gk0_gm0_gm1_gk1
.
GetLength
(
I0
);
const
auto
GK0
=
a_grid_desc_gk0_gm0_gm1_gk1
.
GetLength
(
I0
);
...
@@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
a_grid_desc_gk0_gm0_gm1_gk1
,
b_grid_desc_gk0_gn0_gn1_gk1
,
c_grid_desc_gm0_gm1_gn0_gn1
))
a_grid_desc_gk0_gm0_gm1_gk1
,
b_grid_desc_gk0_gn0_gn1_gk1
,
c_grid_desc_gm0_gm1_gn0_gn1
))
{
{
throw
std
::
runtime_error
(
"wrong! "
throw
std
::
runtime_error
(
"wrong! "
"Gridwise
Dynamic
Contraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting"
);
"GM0_GM1_GN0_GN1 has invalid setting"
);
}
}
...
@@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
kernel_
dynamic_
contraction_dlops_v1r2
<
const
auto
kernel
=
kernel_contraction_dlops_v1r2
<
GridwiseContraction
,
GridwiseContraction
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
...
@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -205,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -205,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
kernel_
dynamic_
contraction_dlops_v1r2
<
const
auto
kernel
=
kernel_contraction_dlops_v1r2
<
GridwiseContraction
,
GridwiseContraction
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
...
@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -232,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -232,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
kernel_
dynamic_
contraction_dlops_v1r2
<
const
auto
kernel
=
kernel_contraction_dlops_v1r2
<
GridwiseContraction
,
GridwiseContraction
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
...
@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -259,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -259,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
}
}
else
else
{
{
const
auto
kernel
=
kernel_
dynamic_
contraction_dlops_v1r2
<
const
auto
kernel
=
kernel_contraction_dlops_v1r2
<
GridwiseContraction
,
GridwiseContraction
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
...
@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
...
host/driver_offline/include/driver_
dynamic_
convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
→
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
View file @
ccc4a1d3
#ifndef DRIVER_
DYNAMIC_
CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define DRIVER_
DYNAMIC_
CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_
dynamic_
gemm_dlops_v2.hpp"
#include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_operation_wrapper.hpp"
template
<
ck
::
index_t
BlockSize
,
template
<
ck
::
index_t
BlockSize
,
...
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__
void
Run
(
const
ck
::
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
ck
::
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
ck
::
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
ck
::
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
ck
::
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ck
::
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
...
@@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
// weight tensor
const
auto
wei_e_k_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_e_k_global_desc
=
transform_tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
C
),
...
@@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
...
@@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_e_n_ho_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_e_n_ho_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
...
@@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// output tensor
// output tensor
const
auto
out_k_n_ho_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_k_n_ho_wo_global_desc
=
transform_tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
Ho
),
make_pass_through_transform
(
Ho
),
...
@@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
}
}
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_e_k_global_
i
te
rator
_hacks
=
constexpr
auto
a_e_k_global_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_e_k_global_move_slice_window_
i
te
rator
_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
a_e_k_global_move_slice_window_
s
te
p
_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_e_n_ho_wo_global_
i
te
rator
_hacks
=
constexpr
auto
b_e_n_ho_wo_global_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
b_e_n_ho_wo_global_move_slice_window_
i
te
rator
_hack
=
constexpr
auto
b_e_n_ho_wo_global_move_slice_window_
s
te
p
_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
// hack for NKHW format
constexpr
auto
c_k_n_ho_wo_global_tensor_
i
te
rator
_hacks
=
constexpr
auto
c_k_n_ho_wo_global_tensor_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -169,7 +169,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -169,7 +169,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
#if 1
#if 1
// GEMM
// GEMM
using
gridwise_gemm
=
Gridwise
Dynamic
GemmDlops_km_kn_mn_v3
<
using
gridwise_gemm
=
GridwiseGemmDlops_km_kn_mn_v3
<
BlockSize
,
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
...
@@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
Sequence
<
0
,
2
,
3
,
1
>
,
Sequence
<
0
,
2
,
3
,
1
>
,
0
,
0
,
CThreadTransferDstScalarPerVector_W
,
CThreadTransferDstScalarPerVector_W
,
decltype
(
a_e_k_global_
i
te
rator
_hacks
),
decltype
(
a_e_k_global_
s
te
p
_hacks
),
decltype
(
b_e_n_ho_wo_global_
i
te
rator
_hacks
),
decltype
(
b_e_n_ho_wo_global_
s
te
p
_hacks
),
decltype
(
c_k_n_ho_wo_global_tensor_
i
te
rator
_hacks
),
decltype
(
c_k_n_ho_wo_global_tensor_
s
te
p
_hacks
),
decltype
(
a_e_k_global_move_slice_window_
i
te
rator
_hack
),
decltype
(
a_e_k_global_move_slice_window_
s
te
p
_hack
),
decltype
(
b_e_n_ho_wo_global_move_slice_window_
i
te
rator
_hack
)
>
;
decltype
(
b_e_n_ho_wo_global_move_slice_window_
s
te
p
_hack
)
>
;
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Ho
/
HoPerBlock
)
*
(
Wo
/
WoPerBlock
)
*
N
;
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Ho
/
HoPerBlock
)
*
(
Wo
/
WoPerBlock
)
*
N
;
...
@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -338,10 +334,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
...
@@ -338,10 +334,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
float
perf
=
wei_k_c_y_x_global_desc
,
static_cast
<
float
>
(
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
out_n_k0_ho_wo_k1_global_desc
)
/
wei_k_c_y_x_global_desc
,
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
out_n_k0_ho_wo_k1_global_desc
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
<<
std
::
endl
;
...
...
host/driver_offline/include/driver_
dynamic_
convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
→
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
View file @
ccc4a1d3
#ifndef DRIVER_
DYNAMIC_
CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define DRIVER_
DYNAMIC_
CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_
dynamic_
gemm_dlops_v2.hpp"
#include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_operation_wrapper.hpp"
template
<
ck
::
index_t
BlockSize
,
template
<
ck
::
index_t
BlockSize
,
...
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
>
__host__
void
Run
(
const
ck
::
Dynamic
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
ck
::
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
ck
::
Dynamic
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
ck
::
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
ck
::
Dynamic
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ck
::
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
...
@@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
<<
std
::
endl
;
<<
std
::
endl
;
// weight tensor
// weight tensor
const
auto
wei_e_k_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
wei_e_k_global_desc
=
transform_tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_c_hi_wi_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_pass_through_transform
(
C
),
...
@@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
in_n_c_hip_wip_global_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
...
@@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_e_n_ho_wo_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
in_e_n_ho_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
...
@@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// output tensor
// output tensor
const
auto
out_k_n_hop_wop_global_desc
=
transform_
dynamic_
tensor_descriptor
(
const
auto
out_k_n_hop_wop_global_desc
=
transform_tensor_descriptor
(
make_
dynamic_
naive_tensor_descriptor_packed
_v2
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
0
,
OutRightPadH
),
make_pad_transform
(
Ho
,
0
,
OutRightPadH
),
...
@@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
}
}
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_e_k_global_
i
te
rator
_hacks
=
constexpr
auto
a_e_k_global_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_e_k_global_move_slice_window_
i
te
rator
_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
a_e_k_global_move_slice_window_
s
te
p
_hack
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_e_n_ho_wo_global_
i
te
rator
_hacks
=
constexpr
auto
b_e_n_ho_wo_global_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
b_e_n_ho_wo_global_move_slice_window_
i
te
rator
_hack
=
constexpr
auto
b_e_n_ho_wo_global_move_slice_window_
s
te
p
_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
// hack for NKHW format
constexpr
auto
c_k_n_ho_wo_global_tensor_
i
te
rator
_hacks
=
constexpr
auto
c_k_n_ho_wo_global_tensor_
s
te
p
_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
...
@@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// GEMM
// GEMM
using
gridwise_gemm
=
Gridwise
Dynamic
GemmDlops_km_kn_mn_v3
<
using
gridwise_gemm
=
GridwiseGemmDlops_km_kn_mn_v3
<
BlockSize
,
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
...
@@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence
<
0
,
2
,
3
,
1
>
,
Sequence
<
0
,
2
,
3
,
1
>
,
0
,
0
,
CThreadTransferDstScalarPerVector_W
,
CThreadTransferDstScalarPerVector_W
,
decltype
(
a_e_k_global_
i
te
rator
_hacks
),
decltype
(
a_e_k_global_
s
te
p
_hacks
),
decltype
(
b_e_n_ho_wo_global_
i
te
rator
_hacks
),
decltype
(
b_e_n_ho_wo_global_
s
te
p
_hacks
),
decltype
(
c_k_n_ho_wo_global_tensor_
i
te
rator
_hacks
),
decltype
(
c_k_n_ho_wo_global_tensor_
s
te
p
_hacks
),
decltype
(
a_e_k_global_move_slice_window_
i
te
rator
_hack
),
decltype
(
a_e_k_global_move_slice_window_
s
te
p
_hack
),
decltype
(
b_e_n_ho_wo_global_move_slice_window_
i
te
rator
_hack
)
>
;
decltype
(
b_e_n_ho_wo_global_move_slice_window_
s
te
p
_hack
)
>
;
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Hop
/
HoPerBlock
)
*
(
Wop
/
WoPerBlock
)
*
N
;
const
auto
GridSize
=
(
K
/
KPerBlock
)
*
(
Hop
/
HoPerBlock
)
*
(
Wop
/
WoPerBlock
)
*
N
;
...
@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
wei_e_k_global_desc
,
wei_e_k_global_desc
,
p_wei_global
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
in_e_n_ho_wo_global_desc
,
...
@@ -354,10 +350,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -354,10 +350,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
float
perf
=
wei_k_c_y_x_global_desc
,
static_cast
<
float
>
(
calculate_convolution_flops
(
in_n_c_hi_wi_global_desc
,
out_n_k0_ho_wo_k1_global_desc
)
/
wei_k_c_y_x_global_desc
,
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
out_n_k0_ho_wo_k1_global_desc
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
<<
std
::
endl
;
...
...
host/driver_offline/include/driver_
dynamic_
gemm_dlops_v1r2.hpp
→
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
View file @
ccc4a1d3
#ifndef DRIVER_
DYNAMIC_
GEMM_DLOPS_V1R2
#ifndef DRIVER_GEMM_DLOPS_V1R2
#define DRIVER_
DYNAMIC_
GEMM_DLOPS_V1R2
#define DRIVER_GEMM_DLOPS_V1R2
#include "common_header.hpp"
#include "common_header.hpp"
#include "
dynamic_
tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "
dynamic_
tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_
dynamic_
gemm_dlops_v1r2.hpp"
#include "gridwise_gemm_dlops_v1r2.hpp"
template
<
ck
::
index_t
BlockSize
,
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAB
,
...
@@ -43,23 +43,23 @@ template <ck::index_t BlockSize,
...
@@ -43,23 +43,23 @@ template <ck::index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
typename
AGrid
I
te
rator
Hacks
,
typename
AGrid
S
te
p
Hacks
,
typename
BGrid
I
te
rator
Hacks
,
typename
BGrid
S
te
p
Hacks
,
typename
CGrid
I
te
rator
Hacks
,
typename
CGrid
S
te
p
Hacks
,
typename
AGridMoveSliceWindow
I
te
rator
Hacks
,
typename
AGridMoveSliceWindow
S
te
p
Hacks
,
typename
BGridMoveSliceWindow
I
te
rator
Hacks
>
typename
BGridMoveSliceWindow
S
te
p
Hacks
>
__host__
float
driver_
dynamic_
gemm_dlops_v1r2
(
const
FloatAB
*
p_a_grid
,
__host__
float
driver_gemm_dlops_v1r2
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
FloatC
*
p_c_grid
,
const
AKMGridDesc
&
a_k_m_grid_desc
,
const
AKMGridDesc
&
a_k_m_grid_desc
,
const
BKNGridDesc
&
b_k_n_grid_desc
,
const
BKNGridDesc
&
b_k_n_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
AGrid
I
te
rator
Hacks
,
AGrid
S
te
p
Hacks
,
BGrid
I
te
rator
Hacks
,
BGrid
S
te
p
Hacks
,
CGrid
I
te
rator
Hacks
,
CGrid
S
te
p
Hacks
,
AGridMoveSliceWindow
I
te
rator
Hacks
,
AGridMoveSliceWindow
S
te
p
Hacks
,
BGridMoveSliceWindow
I
te
rator
Hacks
,
BGridMoveSliceWindow
S
te
p
Hacks
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -72,49 +72,48 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -72,49 +72,48 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
// GEMM
// GEMM
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemmDlops_km_kn_mn_v1r2
<
BlockSize
,
GridwiseDynamicGemmDlops_km_kn_mn_v1r2
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
AKMGridDesc
,
AKMGridDesc
,
BKNGridDesc
,
BKNGridDesc
,
CMNGridDesc
,
CMNGridDesc
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
M1PerThread
,
M1PerThread
,
N1PerThread
,
N1PerThread
,
KPerThread
,
KPerThread
,
M1N1ThreadClusterM10
,
M1N1ThreadClusterM10
,
M1N1ThreadClusterN10
,
M1N1ThreadClusterN10
,
M1N1ThreadClusterM11
,
M1N1ThreadClusterM11
,
M1N1ThreadClusterN11
,
M1N1ThreadClusterN11
,
ABlockTransferThreadSliceLengths_K_M0_M1
,
ABlockTransferThreadSliceLengths_K_M0_M1
,
ABlockTransferThreadClusterLengths_K_M0_M1
,
ABlockTransferThreadClusterLengths_K_M0_M1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_M1
,
ABlockTransferDstScalarPerVector_M1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
BBlockTransferThreadSliceLengths_K_N0_N1
,
BBlockTransferThreadSliceLengths_K_N0_N1
,
BBlockTransferThreadClusterLengths_K_N0_N1
,
BBlockTransferThreadClusterLengths_K_N0_N1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_N1
,
BBlockTransferDstScalarPerVector_N1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AGridStepHacks
,
AGridIteratorHacks
,
BGridStepHacks
,
BGridIteratorHacks
,
CGridStepHacks
,
CGridIteratorHacks
,
AGridMoveSliceWindowStepHacks
,
AGridMoveSliceWindowIteratorHacks
,
BGridMoveSliceWindowStepHacks
>
;
BGridMoveSliceWindowIteratorHacks
>
;
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_grid_desc
.
GetLength
(
I1
);
...
@@ -122,8 +121,7 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -122,8 +121,7 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k_m_grid_desc
,
b_k_n_grid_desc
,
c_m_n_grid_desc
))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k_m_grid_desc
,
b_k_n_grid_desc
,
c_m_n_grid_desc
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting"
);
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting"
);
}
}
const
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
const
auto
a_k_m0_m1_grid_desc
=
GridwiseGemm
::
MakeAKM0M1GridDescriptor
(
a_k_m_grid_desc
);
...
@@ -174,22 +172,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -174,22 +172,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
true
,
true
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -201,22 +198,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -201,22 +198,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
true
,
true
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -228,22 +224,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -228,22 +224,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
false
,
false
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -255,22 +250,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -255,22 +250,21 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
else
else
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
false
,
false
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -299,15 +293,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -299,15 +293,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
true
,
true
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
kernel
,
kernel
,
...
@@ -315,27 +309,28 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -315,27 +309,28 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
true
,
true
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
kernel
,
kernel
,
...
@@ -343,27 +338,28 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -343,27 +338,28 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
false
,
false
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
kernel
,
kernel
,
...
@@ -371,27 +367,28 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -371,27 +367,28 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
}
else
else
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_
dynamic_
gemm_dlops_v1r2
<
GridwiseGemm
,
kernel_gemm_dlops_v1r2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
AKM0M1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
BKN0N1GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CM0M10M11N0N10N11GridDesc
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
false
,
false
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
kernel
,
kernel
,
...
@@ -399,14 +396,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
...
@@ -399,14 +396,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
cast_pointer_to_constant_address_space
(
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
}
return
ave_time
;
return
ave_time
;
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment