Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5b1a9994
Commit
5b1a9994
authored
Oct 08, 2021
by
Jing Zhang
Browse files
add conv+add fusion
parent
484ae48f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
2126 additions
and
13 deletions
+2126
-13
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
...l/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
+959
-0
host/driver_offline/CMakeLists.txt
host/driver_offline/CMakeLists.txt
+3
-0
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+223
-0
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
...ward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
+528
-0
host/driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
.../driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
+334
-0
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
+15
-13
host/host_tensor/include/host_conv.hpp
host/host_tensor/include/host_conv.hpp
+64
-0
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2_add.hpp
0 → 100644
View file @
5b1a9994
#ifndef CK_GRIDWISE_GEMM_V2_ADD_HPP
#define CK_GRIDWISE_GEMM_V2_ADD_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp"
namespace
ck
{
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_dlops_v2_add
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_d_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_E0_E1_K0_K1_E2
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_d_grid
,
p_c_grid
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K0_K1_E2
,
typename
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
,
typename
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
,
typename
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
,
typename
CBlockIdToBlockClusterAdaptor_K_N_H_W
,
bool
HasMainE0BlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_dlops_v2_add
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_d_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_e0_e1_k0_k1_e2_grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
void
CONSTANT
*
p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
void
CONSTANT
*
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
)
{
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
*
reinterpret_cast
<
const
AGridDesc_E0_E1_K0_K1_E2
*>
(
cast_pointer_to_generic_address_space
(
p_a_e0_e1_k0_k1_e2_grid_desc
));
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
*
reinterpret_cast
<
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
*>
(
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
));
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
*
reinterpret_cast
<
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
*>
(
cast_pointer_to_generic_address_space
(
p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
));
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
*>
(
cast_pointer_to_generic_address_space
(
p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
));
const
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
*
reinterpret_cast
<
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
*>
(
cast_pointer_to_generic_address_space
(
p_c_blockid_to_k_n_h_w_block_cluster_adaptor
));
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_d_grid
,
p_c_grid
,
p_shared_block
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
{});
}
#endif
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGridDesc_E0_E1_K_E2
,
typename
BGridDesc_E0_E1_N_Ho_Wo_E2
,
typename
DGridDesc_K_N_Hox2_Wox2
,
typename
CGridDesc_K_N_Ho_Wo
,
index_t
E1_
,
index_t
E2_
,
index_t
K2_
,
index_t
KPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
E1PerBlock
,
index_t
KPerThread
,
index_t
HoPerThread
,
index_t
WoPerThread
,
index_t
EPerThread
,
typename
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
,
typename
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_E2
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGlobalStepHacks
,
typename
BGlobalStepHacks
,
typename
CGlobalStepHacks
,
typename
DGlobalStepHacks
,
typename
AGlobalMoveSliceWindowStepHacks
,
typename
BGlobalMoveSliceWindowStepHacks
,
index_t
activ_type
=
0
>
struct
GridwiseGemmDlops_km_kn_mn_v3_add
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
E1
=
Number
<
E1_
>
{};
static
constexpr
auto
E2
=
Number
<
E2_
>
{};
static
constexpr
auto
K2
=
Number
<
K2_
>
{};
static
constexpr
auto
NPerBlock
=
I1
;
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
max_lds_align
=
Number
<
ABlockTransferDstScalarPerVector_E2
>
{};
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_e0_e1_k1_e2_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
I1
,
Number
<
E1
>
{},
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_e0_e1_k1_e2_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
return
a_block_space_size
*
sizeof
(
FloatAB
);
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
{
const
auto
K
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I1
);
const
auto
Ho
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I2
);
const
auto
Wo
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I3
);
const
auto
K0
=
K
/
KPerBlock
;
const
auto
N0
=
N
/
NPerBlock
;
const
auto
H0
=
Ho
/
HoPerBlock
;
const
auto
W0
=
Wo
/
WoPerBlock
;
const
index_t
grid_size
=
K0
*
N0
*
H0
*
W0
;
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainE0BlockLoop
(
const
index_t
E0
)
{
const
bool
has_main_e0_block_loop
=
E0
>
1
;
return
has_main_e0_block_loop
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainE1BlockLoop
()
{
const
bool
has_main_e1_block_loop
=
(
E1
+
E1PerBlock
)
/
(
2
*
E1PerBlock
)
>
1
;
return
has_main_e1_block_loop
;
}
__host__
__device__
static
constexpr
bool
CalculateHasDoubleTailE1BlockLoop
()
{
const
bool
has_double_tail_e1_block_loop
=
(
E1
/
E1PerBlock
)
%
2
==
0
;
return
has_double_tail_e1_block_loop
;
}
__host__
__device__
static
constexpr
auto
MakeAE0E1K0K1E2GridDescriptor
(
const
AGridDesc_E0_E1_K_E2
&
a_e0_e1_k_e2_grid_desc
)
{
const
auto
E0
=
a_e0_e1_k_e2_grid_desc
.
GetLength
(
I0
);
const
auto
K
=
a_e0_e1_k_e2_grid_desc
.
GetLength
(
I2
);
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
transform_tensor_descriptor
(
a_e0_e1_k_e2_grid_desc
,
make_tuple
(
make_pass_through_transform
(
E0
),
make_pass_through_transform
(
E1
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}));
return
a_e0_e1_k0_k1_e2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
const
BGridDesc_E0_E1_N_Ho_Wo_E2
&
b_e0_e1_n_ho_wo_e2_grid_desc
)
{
const
auto
E0
=
b_e0_e1_n_ho_wo_e2_grid_desc
.
GetLength
(
I0
);
// const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1);
const
auto
N
=
b_e0_e1_n_ho_wo_e2_grid_desc
.
GetLength
(
I2
);
const
auto
Ho
=
b_e0_e1_n_ho_wo_e2_grid_desc
.
GetLength
(
I3
);
const
auto
Wo
=
b_e0_e1_n_ho_wo_e2_grid_desc
.
GetLength
(
I4
);
// const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5);
const
auto
H2
=
Number
<
HoPerThread
>
{};
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
Ho
/
(
H1
*
H2
);
const
auto
W2
=
Number
<
WoPerThread
>
{};
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
Wo
/
(
W1
*
W2
);
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
transform_tensor_descriptor
(
b_e0_e1_n_ho_wo_e2_grid_desc
,
make_tuple
(
make_pass_through_transform
(
E0
),
make_pass_through_transform
(
E1
),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2
)),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{},
Sequence
<
9
>
{}));
return
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
{
const
auto
K
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I1
);
const
auto
Ho
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I2
);
const
auto
Wo
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I3
);
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
H2
=
Number
<
HoPerThread
>
{};
const
auto
H1
=
Number
<
HoPerBlock
/
HoPerThread
>
{};
const
auto
H0
=
Ho
/
(
H1
*
H2
);
const
auto
W2
=
Number
<
WoPerThread
>
{};
const
auto
W1
=
Number
<
WoPerBlock
/
WoPerThread
>
{};
const
auto
W0
=
Wo
/
(
W1
*
W2
);
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
transform_tensor_descriptor
(
c_k_n_ho_wo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{}));
return
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
const
DGridDesc_K_N_Hox2_Wox2
&
d_k_n_hox2_wox2_grid_desc
)
{
const
auto
K
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I1
);
const
auto
Hox2
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I2
);
const
auto
Wox2
=
d_k_n_hox2_wox2_grid_desc
.
GetLength
(
I3
);
const
auto
K1
=
Number
<
KPerBlock
>
{};
const
auto
K0
=
K
/
K1
;
const
auto
HoPerBlockx2
=
HoPerBlock
*
2
;
const
auto
WoPerBlockx2
=
WoPerBlock
*
2
;
const
auto
HoPerThreadx2
=
HoPerThread
*
2
;
const
auto
WoPerThreadx2
=
WoPerThread
*
2
;
const
auto
H2x2
=
Number
<
HoPerThreadx2
>
{};
const
auto
H1
=
Number
<
HoPerBlockx2
/
HoPerThreadx2
>
{};
const
auto
H0
=
Hox2
/
(
H1
*
H2x2
);
const
auto
W2x2
=
Number
<
WoPerThreadx2
>
{};
const
auto
W1
=
Number
<
WoPerBlockx2
/
WoPerThreadx2
>
{};
const
auto
W0
=
Wox2
/
(
W1
*
W2x2
);
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
transform_tensor_descriptor
(
d_k_n_hox2_wox2_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
H0
,
H1
,
H2x2
)),
make_unmerge_transform
(
make_tuple
(
W0
,
W1
,
W2x2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{},
Sequence
<
6
,
7
,
8
>
{}));
return
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
const
CGridDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_grid_desc
)
{
const
auto
K
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I1
);
const
auto
Ho
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I2
);
const
auto
Wo
=
c_k_n_ho_wo_grid_desc
.
GetLength
(
I3
);
const
auto
K0
=
K
/
KPerBlock
;
const
auto
N0
=
N
/
NPerBlock
;
const
auto
H0
=
Ho
/
HoPerBlock
;
const
auto
W0
=
Wo
/
WoPerBlock
;
const
auto
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
N0
,
H0
,
W0
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
;
}
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
MakeAE0E1K0K1E2GridDescriptor
(
AGridDesc_E0_E1_K_E2
{}));
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
BGridDesc_E0_E1_N_Ho_Wo_E2
{}));
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
CGridDesc_K_N_Ho_Wo
{}));
using
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
=
decltype
(
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
DGridDesc_K_N_Hox2_Wox2
{}));
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
CGridDesc_K_N_Ho_Wo
{}));
template
<
bool
HasMainE0BlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
FloatC
*
__restrict__
p_d_global
,
FloatC
*
__restrict__
p_c_global
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_E0_E1_K0_K1_E2
&
a_e0_e1_k0_k1_e2_grid_desc
,
const
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
const
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
&
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
const
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_H_W
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
,
integral_constant
<
bool
,
HasMainE0BlockLoop
>
)
{
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k0_k1_e2_grid_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
.
GetElementSpaceSize
());
(
void
)
c_global_buf
;
auto
d_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_d_global
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
.
GetElementSpaceSize
());
constexpr
auto
HasMainE1BlockLoop
=
CalculateHasMainE1BlockLoop
();
constexpr
auto
HasDoubleTailE1BlockLoop
=
CalculateHasDoubleTailE1BlockLoop
();
// const auto Ho = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I3);
// const auto Wo = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I4);
const
auto
c_k_n_h_w_block_cluster_idx
=
c_blockid_to_k_n_h_w_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
k_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I0
]);
const
index_t
n_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I1
]);
const
index_t
ho_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I2
]);
const
index_t
wo_block_work_id
=
__builtin_amdgcn_readfirstlane
(
c_k_n_h_w_block_cluster_idx
[
I3
]);
constexpr
auto
max_lds_align
=
Number
<
ABlockTransferDstScalarPerVector_E2
>
{};
constexpr
auto
a_e1_k1_e2_block_gemm_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
E1PerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
constexpr
auto
b_e1_n_h_w_e2_block_gemm_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
E1PerBlock
>
{},
I1
,
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{},
Number
<
E2
>
{}));
constexpr
auto
c_k1_n_h2_w2_thread_gemm_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
I1
,
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
auto
blockwise_gemm
=
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_e1_k1_e2_block_gemm_desc
),
decltype
(
b_e1_n_h_w_e2_block_gemm_desc
),
decltype
(
c_k1_n_h2_w2_thread_gemm_desc
),
EPerThread
,
K2
>
{};
auto
c_thread_mtx_index
=
blockwise_gemm
.
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
get_thread_local_1d_id
());
const
auto
k_thread_id
=
c_thread_mtx_index
[
I0
];
const
auto
ho_thread_id
=
c_thread_mtx_index
[
I2
];
const
auto
wo_thread_id
=
c_thread_mtx_index
[
I3
];
constexpr
auto
a_e0_e1_k0_k1_e2_block_copy_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
I1
>
{},
Number
<
E1
>
{},
I1
,
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
I1
,
E1
,
I1
,
KPerBlock
,
E2
>
,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
),
decltype
(
a_e0_e1_k0_k1_e2_block_copy_desc
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// ABlockTransferDstAccessOrder
ABlockTransferSrcVectorDim
,
4
,
// ABlockTransferDstVectorDim
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_E2
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
false
>
(
a_e0_e1_k0_k1_e2_grid_desc
,
make_multi_index
(
0
,
0
,
k_block_work_id
,
0
,
0
),
a_e0_e1_k0_k1_e2_block_copy_desc
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
));
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
I1
,
0
,
0
,
0
,
0
);
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
E1PerBlock
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread
>
{},
I1
,
I1
,
Number
<
WoPerThread
>
{},
Number
<
E2
>
{}));
auto
b_threadwise_transfer
=
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
),
Sequence
<
I1
,
E1PerBlock
,
I1
,
I1
,
I1
,
HoPerThread
,
I1
,
I1
,
WoPerThread
,
E2
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
make_multi_index
(
0
,
0
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
,
0
));
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_shared_block
,
a_e0_e1_k0_k1_e2_block_copy_desc
.
GetElementSpaceSize
());
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAcc
,
c_k1_n_h2_w2_thread_gemm_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
#if 0
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k1_n_h2_w2_thread_gemm_desc),
Sequence<KPerThread, NPerBlock, HoPerThread, WoPerThread>>{}
.Run(c_k1_n_h2_w2_thread_gemm_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
#endif
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
0
,
E1PerBlock
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_e0_e1_k_e2_global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
=
BGlobalStepHacks
{};
// double regsiter buffer for b
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAB
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
.
GetElementSpaceSize
(),
true
>
b_thread_even_buf
,
b_thread_odd_buf
;
if
constexpr
(
HasMainE0BlockLoop
)
{
const
auto
E0
=
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
.
GetLength
(
I0
);
index_t
e0_block_data_begin
=
0
;
do
{
// LDS double buffer: preload data
{
a_blockwise_copy
.
RunRead
(
a_e0_e1_k0_k1_e2_grid_desc
,
a_global_buf
,
a_e0_e1_k_e2_global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e0_e1_k0_k1_e2_block_copy_desc
,
a_block_buf
);
}
__syncthreads
();
if
constexpr
(
HasMainE1BlockLoop
)
{
index_t
e1_block_data_begin
=
0
;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
e1_block_data_begin
+=
2
*
E1PerBlock
;
}
while
(
e1_block_data_begin
<
E1
-
2
*
E1PerBlock
);
}
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailE1BlockLoop
)
// if has 2 iteration left
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
}
else
// if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
}
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_e0_e1_k0_k1_e2_grid_desc
,
a_block_slice_copy_step
,
AGlobalMoveSliceWindowStepHacks
{});
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
-
(
E1
-
E1PerBlock
),
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
e0_block_data_begin
+=
1
;
}
while
(
e0_block_data_begin
<
E0
);
}
else
{
// LDS double buffer: preload data
{
a_blockwise_copy
.
RunRead
(
a_e0_e1_k0_k1_e2_grid_desc
,
a_global_buf
,
a_e0_e1_k_e2_global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e0_e1_k0_k1_e2_block_copy_desc
,
a_block_buf
);
}
__syncthreads
();
if
constexpr
(
HasMainE1BlockLoop
)
{
index_t
e1_block_data_begin
=
0
;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
e1_block_data_begin
+=
2
*
E1PerBlock
;
}
while
(
e1_block_data_begin
<
E1
-
2
*
E1PerBlock
);
}
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailE1BlockLoop
)
// if has 2 iteration left
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_thread_slice_copy_step
,
BGlobalMoveSliceWindowStepHacks
{});
b_threadwise_transfer
.
Run
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
b_global_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
E1PerBlock
,
0
,
0
));
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
}
else
// if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
}
}
// activ
{
static_for
<
0
,
c_k1_n_h2_w2_thread_gemm_desc
.
GetElementSpaceSize
(),
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
activ_type
==
1
)
{
c_thread_buf
(
i
)
=
c_thread_buf
[
i
]
>=
0
?
c_thread_buf
[
i
]
:
0.0
;
}
else
if
constexpr
(
activ_type
==
2
)
{
FloatAcc
x
=
1.0
+
exp
(
-
c_thread_buf
[
i
]);
asm
volatile
(
"
\n
\
v_rcp_f32 %0, %1
\n
"
:
"=v"
(
x
)
:
"0"
(
x
));
c_thread_buf
(
i
)
=
x
;
}
});
}
// Resize_Add
{
constexpr
auto
HoPerThreadx2
=
HoPerThread
*
2
;
constexpr
auto
WoPerThreadx2
=
WoPerThread
*
2
;
constexpr
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThreadx2
>
{},
I1
,
I1
,
Number
<
WoPerThreadx2
>
{}));
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatC
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
.
GetElementSpaceSize
(),
true
>
d_thread_buf
;
// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
constexpr
auto
d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
=
DGlobalStepHacks
{};
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
#if 1
ThreadwiseTensorSliceTransfer_v2
<
FloatC
,
FloatC
,
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
),
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThreadx2
,
I1
,
I1
,
WoPerThreadx2
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
1
,
true
>
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
d_global_buf
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d_thread_buf
,
d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
);
#endif
static_for
<
0
,
KPerThread
,
1
>
{}([
&
](
auto
k_i
)
{
static_for
<
0
,
HoPerThreadx2
,
1
>
{}([
&
](
auto
h_i
)
{
static_for
<
0
,
WoPerThreadx2
,
1
>
{}([
&
](
auto
w_i
)
{
d_thread_buf
(
Number
<
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
k_i
,
0
,
0
,
0
,
h_i
,
0
,
0
,
w_i
))
>
{})
=
1
;
// c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
// make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
});
});
});
ThreadwiseTensorSliceTransfer_v1r3
<
FloatC
,
FloatC
,
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
),
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThreadx2
,
I1
,
I1
,
WoPerThreadx2
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
d_thread_buf
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
d_global_buf
,
d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
);
}
#if 1
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr
auto
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
=
CGlobalStepHacks
{};
constexpr
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
KPerThread
>
{},
I1
,
I1
,
I1
,
Number
<
HoPerThread
>
{},
I1
,
I1
,
Number
<
WoPerThread
>
{}));
const
index_t
k_thread_data_on_global
=
k_thread_id
*
KPerThread
;
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
),
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
),
Sequence
<
I1
,
KPerThread
,
I1
,
I1
,
I1
,
HoPerThread
,
I1
,
I1
,
WoPerThread
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
make_multi_index
(
k_block_work_id
,
k_thread_data_on_global
,
n_block_work_id
,
ho_block_work_id
,
ho_thread_id
,
0
,
wo_block_work_id
,
wo_thread_id
,
0
))
.
Run
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_global_buf
,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
);
}
#endif
}
};
}
// namespace ck
#endif
host/driver_offline/CMakeLists.txt
View file @
5b1a9994
...
...
@@ -13,18 +13,21 @@ include_directories(BEFORE
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp
)
set
(
CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_fwd_driver_offline_nchwc.cpp
)
set
(
CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_add_fwd_driver_offline_nchwc.cpp
)
set
(
CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp
)
set
(
CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp
)
set
(
GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp
)
add_executable
(
conv_fwd_driver_offline
${
CONV_FWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_fwd_driver_offline_nchwc
${
CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE
}
)
add_executable
(
conv_add_fwd_driver_offline_nchwc
${
CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE
}
)
add_executable
(
conv_bwd_driver_offline
${
CONV_BWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_wrw_driver_offline
${
CONV_WRW_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
gemm_driver_offline
${
GEMM_DRIVER_OFFLINE_SOURCE
}
)
target_link_libraries
(
conv_fwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_fwd_driver_offline_nchwc PRIVATE host_tensor
)
target_link_libraries
(
conv_add_fwd_driver_offline_nchwc PRIVATE host_tensor
)
target_link_libraries
(
conv_bwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_wrw_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
gemm_driver_offline PRIVATE host_tensor
)
host/driver_offline/include/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
0 → 100644
View file @
5b1a9994
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
template
<
typename
TInWei
,
typename
TAcc
,
typename
TOut
,
ck
::
index_t
activ_type
,
typename
InLengths
,
typename
WeiLengths
,
typename
AddLengths
,
typename
OutLengths
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
(
const
InLengths
&
in_n_c0_hi_wi_c1_lengths
,
const
WeiLengths
&
wei_k_c0_y_x_c1_lengths
,
const
AddLengths
&
add_n_k0_hox2_wox2_k1_lengths
,
const
OutLengths
&
out_n_k0_ho_wo_k1_lengths
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
Tensor
<
TInWei
>&
in_n_c0_hi_wi_c1
,
const
Tensor
<
TInWei
>&
wei_k_c0_y_x_c1
,
const
Tensor
<
TOut
>&
add_n_k0_hox2_wox2_k1
,
Tensor
<
TOut
>&
add_n_k0_hox2_wox2_k1_out
,
Tensor
<
TOut
>&
out_n_k0_ho_wo_k1
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
const
auto
N
=
out_n_k0_ho_wo_k1_lengths
[
I0
];
const
auto
K0
=
out_n_k0_ho_wo_k1_lengths
[
I1
];
const
auto
Ho
=
out_n_k0_ho_wo_k1_lengths
[
I2
];
const
auto
Wo
=
out_n_k0_ho_wo_k1_lengths
[
I3
];
const
auto
K1
=
out_n_k0_ho_wo_k1_lengths
[
I4
];
const
auto
C0
=
in_n_c0_hi_wi_c1_lengths
[
I1
];
const
auto
Hi
=
in_n_c0_hi_wi_c1_lengths
[
I2
];
const
auto
Wi
=
in_n_c0_hi_wi_c1_lengths
[
I3
];
const
auto
C1
=
in_n_c0_hi_wi_c1_lengths
[
I4
];
const
auto
K
=
wei_k_c0_y_x_c1_lengths
[
I0
];
const
auto
Y
=
wei_k_c0_y_x_c1_lengths
[
I2
];
const
auto
X
=
wei_k_c0_y_x_c1_lengths
[
I3
];
const
auto
Hox2
=
add_n_k0_hox2_wox2_k1_lengths
[
I2
];
const
auto
Wox2
=
add_n_k0_hox2_wox2_k1_lengths
[
I3
];
DeviceMem
in_n_c0_hi_wi_c1_device_buf
(
sizeof
(
TInWei
)
*
in_n_c0_hi_wi_c1
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c0_y_x_c1_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c0_y_x_c1
.
mDesc
.
GetElementSpace
());
DeviceMem
add_n_k0_hox2_wox2_k1_device_buf
(
sizeof
(
TOut
)
*
add_n_k0_hox2_wox2_k1
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_k0_ho_wo_k1_device_buf
(
sizeof
(
TOut
)
*
out_n_k0_ho_wo_k1
.
mDesc
.
GetElementSpace
());
in_n_c0_hi_wi_c1_device_buf
.
ToDevice
(
in_n_c0_hi_wi_c1
.
mData
.
data
());
wei_k_c0_y_x_c1_device_buf
.
ToDevice
(
wei_k_c0_y_x_c1
.
mData
.
data
());
add_n_k0_hox2_wox2_k1_device_buf
.
ToDevice
(
add_n_k0_hox2_wox2_k1
.
mData
.
data
());
constexpr
index_t
InWeiVectorSize
=
8
;
if
(
C1
%
InWeiVectorSize
!=
0
)
{
throw
std
::
runtime_error
(
"wrong! C1 cannot be divided by InWeiVectorSize"
);
}
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 32;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 64;
constexpr index_t E1 = C0 * 9;
constexpr index_t E2 = 1;
constexpr index_t E1PerBlock = C0;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif
1
constexpr
auto
BlockSize
=
64
;
constexpr
auto
KPerBlock
=
16
;
constexpr
auto
HoPerBlock
=
8
;
constexpr
auto
WoPerBlock
=
32
;
constexpr
auto
E1
=
2
*
9
;
constexpr
auto
E2
=
1
;
constexpr
auto
K2
=
2
;
constexpr
auto
E1PerBlock
=
2
;
constexpr
auto
KPerThread
=
16
;
constexpr
auto
HoPerThread
=
2
;
constexpr
auto
WoPerThread
=
2
;
constexpr
auto
EPerThread
=
1
;
using
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
=
Sequence
<
1
,
9
,
1
,
1
,
E2
>
;
using
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
=
Sequence
<
1
,
E1PerBlock
,
1
,
KPerBlock
,
1
>
;
constexpr
auto
ABlockTransferSrcScalarPerVector_E2
=
E2
;
constexpr
auto
ABlockTransferDstScalarPerVector_E2
=
E2
;
constexpr
auto
BThreadTransferSrcScalarPerVector_E2
=
E2
;
constexpr
auto
CThreadTransferDstScalarPerVector_K
=
8
;
#endif
const
auto
in_n_c0_hi_wi_c1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
C0
,
Hi
,
Wi
,
C1
));
const
auto
wei_k_c0_y_x_c1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C0
,
Y
,
X
,
C1
));
const
auto
add_n_k0_hox2_wox2_k1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
,
K1
));
const
auto
out_n_k0_ho_wo_k1_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
));
static_assert
(
in_n_c0_hi_wi_c1_desc
.
IsKnownAtCompileTime
(),
""
);
static_assert
(
wei_k_c0_y_x_c1_desc
.
IsKnownAtCompileTime
(),
""
);
static_assert
(
add_n_k0_hox2_wox2_k1_desc
.
IsKnownAtCompileTime
(),
""
);
static_assert
(
out_n_k0_ho_wo_k1_desc
.
IsKnownAtCompileTime
(),
""
);
constexpr
auto
conv_driver
=
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
<
BlockSize
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
TAcc
,
TOut
,
E1
,
E2
,
K2
,
KPerBlock
,
HoPerBlock
,
WoPerBlock
,
E1PerBlock
,
KPerThread
,
HoPerThread
,
WoPerThread
,
EPerThread
,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
,
ABlockTransferSrcScalarPerVector_E2
,
ABlockTransferDstScalarPerVector_E2
,
BThreadTransferSrcScalarPerVector_E2
,
CThreadTransferDstScalarPerVector_K
,
activ_type
>
{};
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
const
auto
ave_time
=
conv_driver
.
Run
(
wei_k_c0_y_x_c1_desc
,
in_n_c0_hi_wi_c1_desc
,
add_n_k0_hox2_wox2_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c0_y_x_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c0_hi_wi_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k0_hox2_wox2_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k0_ho_wo_k1_device_buf
.
GetDeviceBuffer
()),
nrepeat
);
{
float
perf
=
static_cast
<
float
>
(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C0
*
C1
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
}
add_n_k0_hox2_wox2_k1_device_buf
.
ToDevice
(
add_n_k0_hox2_wox2_k1
.
mData
.
data
());
conv_driver
.
Run
(
wei_k_c0_y_x_c1_desc
,
in_n_c0_hi_wi_c1_desc
,
add_n_k0_hox2_wox2_k1_desc
,
out_n_k0_ho_wo_k1_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c0_y_x_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c0_hi_wi_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
add_n_k0_hox2_wox2_k1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k0_ho_wo_k1_device_buf
.
GetDeviceBuffer
()),
1
);
add_n_k0_hox2_wox2_k1_device_buf
.
FromDevice
(
add_n_k0_hox2_wox2_k1_out
.
mData
.
data
());
out_n_k0_ho_wo_k1_device_buf
.
FromDevice
(
out_n_k0_ho_wo_k1
.
mData
.
data
());
}
host/driver_offline/include/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
0 → 100644
View file @
5b1a9994
#ifndef DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#define DRIVER_CONVOLUTION_ADD_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp"
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
ck
::
index_t
E1_
,
ck
::
index_t
E2_
,
ck
::
index_t
K2_
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
HoPerBlock
,
ck
::
index_t
WoPerBlock
,
ck
::
index_t
E1PerBlock
,
ck
::
index_t
KPerThread
,
ck
::
index_t
HoPerThread
,
ck
::
index_t
WoPerThread
,
ck
::
index_t
EPerThread
,
typename
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
,
typename
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
,
ck
::
index_t
ABlockTransferSrcScalarPerVector_E2
,
ck
::
index_t
ABlockTransferDstScalarPerVector_E2
,
ck
::
index_t
BThreadTransferSrcScalarPerVector_E2
,
ck
::
index_t
CThreadTransferDstScalarPerVector_K
,
ck
::
index_t
activ_type
>
struct
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Add
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
__host__
float
Run
(
const
ck
::
TensorDescriptor
<
Wei
...
>&
wei_k_c0_y_x_c1_global_desc
,
const
ck
::
TensorDescriptor
<
In
...
>&
in_n_c0_hi_wi_c1_global_desc
,
const
ck
::
TensorDescriptor
<
Add
...
>&
add_n_k0_hox2_wox2_k1_global_desc
,
const
ck
::
TensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_d_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
int
nrepeat
)
const
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
const
auto
N
=
in_n_c0_hi_wi_c1_global_desc
.
GetLength
(
I0
);
const
auto
C0
=
in_n_c0_hi_wi_c1_global_desc
.
GetLength
(
I1
);
const
auto
Hi
=
in_n_c0_hi_wi_c1_global_desc
.
GetLength
(
I2
);
const
auto
Wi
=
in_n_c0_hi_wi_c1_global_desc
.
GetLength
(
I3
);
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
const
auto
K0
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I1
);
const
auto
Ho
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I2
);
const
auto
Wo
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I3
);
const
auto
K1
=
out_n_k0_ho_wo_k1_global_desc
.
GetLength
(
I4
);
const
auto
Hox2
=
add_n_k0_hox2_wox2_k1_global_desc
.
GetLength
(
I2
);
const
auto
Wox2
=
add_n_k0_hox2_wox2_k1_global_desc
.
GetLength
(
I3
);
const
auto
K
=
wei_k_c0_y_x_c1_global_desc
.
GetLength
(
I0
);
const
auto
Y
=
wei_k_c0_y_x_c1_global_desc
.
GetLength
(
I2
);
const
auto
X
=
wei_k_c0_y_x_c1_global_desc
.
GetLength
(
I3
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
#if 1
const
auto
Hop
=
Number
<
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
*
HoPerBlock
>
{};
const
auto
Wop
=
Number
<
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
*
WoPerBlock
>
{};
#else
const
auto
Hop
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
*
HoPerBlock
;
const
auto
Wop
=
(
Wo
+
WoPerBlock
-
1
)
/
WoPerBlock
*
WoPerBlock
;
#endif
const
auto
OutRightPadH
=
Hop
-
Ho
;
const
auto
OutRightPadW
=
Wop
-
Wo
;
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
]
+
OutRightPadH
*
ConvStrideH
;
const
auto
InRightPadW
=
in_right_pads
[
I1
]
+
OutRightPadW
*
ConvStrideW
;
const
auto
E
=
C0
*
Y
*
X
;
constexpr
auto
E1
=
Number
<
E1_
>
{};
constexpr
auto
E2
=
Number
<
E2_
>
{};
constexpr
auto
K2
=
Number
<
K2_
>
{};
const
auto
E0
=
E
/
E1
;
// weight tensor
const
auto
a_e_k_e2_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C0
*
Y
*
X
,
E2
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C0
*
Y
*
X
),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{}));
const
auto
a_e0_e1_k_e2_grid_desc
=
transform_tensor_descriptor
(
a_e_k_e2_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
E0
,
E1
)),
make_pass_through_transform
(
K
),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// input tensor
const
auto
in_n_c0_hip_wip_e2_global_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C0
,
Hi
,
Wi
,
E2
)),
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C0
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_c0_y_ho_x_wo_e2_global_desc
=
transform_tensor_descriptor
(
in_n_c0_hip_wip_e2_global_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C0
),
make_embed_transform
(
make_tuple
(
Y
,
Hop
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wop
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{}));
const
auto
b_e_n_ho_wo_e2_grid_desc
=
transform_tensor_descriptor
(
in_n_c0_y_ho_x_wo_e2_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C0
,
Y
,
X
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
Hop
),
make_pass_through_transform
(
Wop
),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
b_e0_e1_n_ho_wo_e2_grid_desc
=
transform_tensor_descriptor
(
b_e_n_ho_wo_e2_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
E0
,
E1
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
Hop
),
make_pass_through_transform
(
Wop
),
make_pass_through_transform
(
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}));
// output tensor
const
auto
c_k_n_hop_wop_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_pad_transform
(
Ho
,
I0
,
OutRightPadH
),
make_pad_transform
(
Wo
,
I0
,
OutRightPadW
)),
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
// add tensor
const
auto
d_k_n_hopx2_wopx2_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
,
K1
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
K1
)),
make_pass_through_transform
(
N
),
make_pad_transform
(
Hox2
,
I0
,
Number
<
OutRightPadH
*
2
>
{}),
make_pad_transform
(
Wox2
,
I0
,
Number
<
OutRightPadW
*
2
>
{})),
make_tuple
(
Sequence
<
1
,
4
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
std
::
cerr
<<
"Hop = "
<<
Hop
<<
" Wop = "
<<
Wop
<<
std
::
endl
;
if
(
!
((
K
%
KPerBlock
)
==
0
&&
(
Hop
%
HoPerBlock
)
==
0
&&
(
Wop
%
WoPerBlock
)
==
0
&&
(
E1
%
E1PerBlock
)
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
// clang-format off
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
constexpr
auto
a_e0_e1_k_e2_global_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
a_e0_e1_k_e2_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{})
);
constexpr
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_k0_k1_n_h0_h1_h2_w0_w1_w2_global tensor
constexpr
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// clang-format on
static_assert
(
a_e0_e1_k_e2_grid_desc
.
IsKnownAtCompileTime
(),
""
);
static_assert
(
b_e0_e1_n_ho_wo_e2_grid_desc
.
IsKnownAtCompileTime
(),
""
);
static_assert
(
d_k_n_hopx2_wopx2_grid_desc
.
IsKnownAtCompileTime
(),
""
);
static_assert
(
c_k_n_hop_wop_grid_desc
.
IsKnownAtCompileTime
(),
""
);
// GEMM
using
GridwiseGemm
=
GridwiseGemmDlops_km_kn_mn_v3_add
<
BlockSize
,
FloatAB
,
FloatAcc
,
FloatC
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_e0_e1_k_e2_grid_desc
),
decltype
(
b_e0_e1_n_ho_wo_e2_grid_desc
),
decltype
(
d_k_n_hopx2_wopx2_grid_desc
),
decltype
(
c_k_n_hop_wop_grid_desc
),
E1
,
E2
,
K2
,
KPerBlock
,
HoPerBlock
,
WoPerBlock
,
E1PerBlock
,
KPerThread
,
HoPerThread
,
WoPerThread
,
EPerThread
,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2
,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2
,
Sequence
<
2
,
3
,
0
,
1
,
4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
4
,
ABlockTransferSrcScalarPerVector_E2
,
ABlockTransferDstScalarPerVector_E2
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
>
,
// E0, E1, N, H0, H1, H2, W0, W1, W2, E2
9
,
BThreadTransferSrcScalarPerVector_E2
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
>
,
// K0, K1, N, H0, H1, H2, W0, W1, W2
1
,
CThreadTransferDstScalarPerVector_K
,
decltype
(
a_e0_e1_k_e2_global_step_hacks
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks
),
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks
),
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_e2_global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack
),
activ_type
>
;
const
auto
a_e0_e1_k0_k1_e2_grid_desc
=
GridwiseGemm
::
MakeAE0E1K0K1E2GridDescriptor
(
a_e0_e1_k_e2_grid_desc
);
const
auto
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
=
GridwiseGemm
::
MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor
(
b_e0_e1_n_ho_wo_e2_grid_desc
);
const
auto
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
=
GridwiseGemm
::
MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor
(
d_k_n_hopx2_wopx2_grid_desc
);
const
auto
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
=
GridwiseGemm
::
MakeCK0K1NH0H1H2W0W1W2GridDescriptor
(
c_k_n_hop_wop_grid_desc
);
using
AGridDesc_E0_E1_K0_K1_E2
=
decltype
(
a_e0_e1_k0_k1_e2_grid_desc
);
using
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
=
decltype
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
using
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
=
decltype
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
using
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
=
decltype
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
);
const
auto
grid_size
=
(
K
/
KPerBlock
)
*
(
Hop
/
HoPerBlock
)
*
(
Wop
/
WoPerBlock
)
*
N
;
const
bool
has_main_e0_block_loop
=
E0
>
1
;
std
::
cerr
<<
"has_main_e0_block_loop = "
<<
has_main_e0_block_loop
<<
std
::
endl
;
const
auto
c_blockid_to_k_n_h_w_block_cluster_adaptor
=
GridwiseGemm
::
MakeCBlockIdToKNHoWoBlockClusterAdaptor
(
c_k_n_hop_wop_grid_desc
);
using
CBlockIdToBlockClusterAdaptor_K_N_H_W
=
decltype
(
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
float
ave_time
=
0
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if
(
has_main_e0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_dlops_v2_add
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_d_grid
,
p_c_grid
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
}
else
{
const
auto
kernel
=
kernel_gemm_dlops_v2_add
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_d_grid
,
p_c_grid
,
a_e0_e1_k0_k1_e2_grid_desc
,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
,
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
(
sizeof
(
AGridDesc_E0_E1_K0_K1_E2
));
DeviceMem
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
(
sizeof
(
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
));
DeviceMem
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
(
sizeof
(
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
));
DeviceMem
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
(
sizeof
(
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
));
DeviceMem
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockIdToBlockClusterAdaptor_K_N_H_W
));
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
ToDevice
(
&
a_e0_e1_k0_k1_e2_grid_desc
);
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
ToDevice
(
&
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc
);
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
ToDevice
(
&
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc
);
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
ToDevice
(
&
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc
);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_blockid_to_k_n_h_w_block_cluster_adaptor
);
if
(
has_main_e0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_dlops_v2_add
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_d_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
{
const
auto
kernel
=
kernel_gemm_dlops_v2_add
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K0_K1_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2
>
,
remove_reference_t
<
DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2
>
,
remove_reference_t
<
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_H_W
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_d_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
#endif
return
ave_time
;
}
};
#endif
host/driver_offline/src/conv_add_fwd_driver_offline_nchwc.cpp
0 → 100644
View file @
5b1a9994
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
//#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V5R1_NCHWC 1
enum
ConvForwardAlgo
{
V5R1NCHWC
// 0
};
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
#if USE_DYNAMIC_MODE
// dynamic mode
if
(
argc
!=
23
)
{
printf
(
"arg1 to 5: algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
1
);
}
constexpr
index_t
activ_type
=
0
;
const
ConvForwardAlgo
algo
=
static_cast
<
ConvForwardAlgo
>
(
std
::
stoi
(
argv
[
1
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
2
]);
const
int
init_method
=
std
::
stoi
(
argv
[
3
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
4
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
5
]);
const
index_t
N
=
std
::
stoi
(
argv
[
6
]);
const
index_t
K0
=
std
::
stoi
(
argv
[
7
]);
const
index_t
K1
=
std
::
stoi
(
argv
[
8
]);
const
index_t
C0
=
std
::
stoi
(
argv
[
9
]);
const
index_t
C1
=
std
::
stoi
(
argv
[
10
]);
const
index_t
Y
=
std
::
stoi
(
argv
[
11
]);
const
index_t
X
=
std
::
stoi
(
argv
[
12
]);
const
index_t
Hi
=
std
::
stoi
(
argv
[
13
]);
const
index_t
Wi
=
std
::
stoi
(
argv
[
14
]);
const
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
15
]);
const
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
16
]);
const
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
17
]);
const
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
18
]);
const
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
19
]);
const
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
20
]);
const
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
21
]);
const
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
22
]);
const
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
#else
// static mode
if
(
argc
<
6
)
{
printf
(
"arg1 to 5: algo, do_verification, init_method, do_log, nrepeat
\n
"
);
exit
(
1
);
}
const
ConvForwardAlgo
algo
=
static_cast
<
ConvForwardAlgo
>
(
std
::
stoi
(
argv
[
1
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
2
]);
const
int
init_method
=
std
::
stoi
(
argv
[
3
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
4
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
5
]);
constexpr
index_t
activ_type
=
0
;
#if 0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif
0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 0
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
270
>
{};
constexpr
auto
Wi
=
Number
<
480
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#elif 1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
135
>
{};
constexpr
auto
Wi
=
Number
<
240
>
{};
constexpr
auto
Y
=
Number
<
3
>
{};
constexpr
auto
X
=
Number
<
3
>
{};
constexpr
auto
C0
=
Number
<
2
>
{};
constexpr
auto
C1
=
Number
<
8
>
{};
constexpr
auto
K1
=
Number
<
8
>
{};
constexpr
auto
K0
=
Number
<
8
>
{};
#endif
constexpr
auto
conv_stride_h
=
I1
;
constexpr
auto
conv_stride_w
=
I1
;
constexpr
auto
conv_dilation_h
=
I1
;
constexpr
auto
conv_dilation_w
=
I1
;
constexpr
auto
in_left_pad_h
=
I1
;
constexpr
auto
in_left_pad_w
=
I1
;
constexpr
auto
in_right_pad_h
=
I1
;
constexpr
auto
in_right_pad_w
=
I1
;
constexpr
auto
YEff
=
(
Y
-
I1
)
*
conv_dilation_h
+
I1
;
constexpr
auto
XEff
=
(
X
-
I1
)
*
conv_dilation_w
+
I1
;
constexpr
auto
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
I1
;
constexpr
auto
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
I1
;
constexpr
auto
Hox2
=
Number
<
Ho
*
2
>
{};
constexpr
auto
Wox2
=
Number
<
Wo
*
2
>
{};
#endif
#if 0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif
1
using
in_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
#elif 1
using
in_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
out_data_t
=
int8_t
;
#endif
std
::
vector
<
std
::
size_t
>
in_lengths_host
(
5
),
wei_lengths_host
(
5
),
out_lengths_host
(
5
),
add_lengths_host
(
5
);
in_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
in_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C0
);
in_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Hi
);
in_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wi
);
in_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
C1
);
wei_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K0
*
K1
);
wei_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C0
);
wei_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Y
);
wei_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
X
);
wei_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
C1
);
out_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
out_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K0
);
out_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Ho
);
out_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wo
);
out_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
K1
);
add_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
add_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K0
);
add_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Hox2
);
add_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wox2
);
add_lengths_host
[
4
]
=
static_cast
<
std
::
size_t
>
(
K1
);
Tensor
<
in_data_t
>
in
(
in_lengths_host
);
Tensor
<
in_data_t
>
wei
(
wei_lengths_host
);
Tensor
<
in_data_t
>
add
(
add_lengths_host
);
Tensor
<
out_data_t
>
out_host
(
out_lengths_host
);
Tensor
<
out_data_t
>
out_device
(
out_lengths_host
);
Tensor
<
in_data_t
>
add_device
(
add_lengths_host
);
Tensor
<
in_data_t
>
add_host
(
add_lengths_host
);
ostream_HostTensorDescriptor
(
in
.
mDesc
,
std
::
cout
<<
"in: "
);
ostream_HostTensorDescriptor
(
wei
.
mDesc
,
std
::
cout
<<
"wei: "
);
ostream_HostTensorDescriptor
(
add
.
mDesc
,
std
::
cout
<<
"add: "
);
ostream_HostTensorDescriptor
(
out_host
.
mDesc
,
std
::
cout
<<
"out: "
);
print_array
(
"InLeftPads"
,
make_tuple
(
in_left_pad_h
,
in_left_pad_w
));
print_array
(
"InRightPads"
,
make_tuple
(
in_right_pad_h
,
in_right_pad_w
));
print_array
(
"ConvStrides"
,
make_tuple
(
conv_stride_h
,
conv_stride_w
));
print_array
(
"ConvDilations"
,
make_tuple
(
conv_dilation_h
,
conv_dilation_w
));
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
// no initialization
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
add
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
2
:
in
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
add
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
3
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
add
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
4
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
add
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
5
:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
0.0
,
1.0
},
num_thread
);
add
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
0.0
,
1.0
},
num_thread
);
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.5
,
0.5
},
num_thread
);
break
;
default:
in
.
GenerateTensorValue
(
GeneratorTensor_2
{
1
,
5
},
num_thread
);
auto
gen_wei
=
[](
auto
...
is
)
{
return
GeneratorTensor_2
{
1
,
5
}(
is
...)
*
GeneratorTensor_Checkboard
{}(
is
...);
};
wei
.
GenerateTensorValue
(
gen_wei
,
num_thread
);
}
auto
f_make_for_device_nchwc
=
[
&
]()
{
const
auto
in_lengths_dev
=
make_tuple
(
N
,
C0
,
Hi
,
Wi
,
C1
);
const
auto
wei_lengths_dev
=
make_tuple
(
K0
*
K1
,
C0
,
Y
,
X
,
C1
);
const
auto
add_lengths_dev
=
make_tuple
(
N
,
K0
,
Hox2
,
Wox2
,
K1
);
const
auto
out_lengths_dev
=
make_tuple
(
N
,
K0
,
Ho
,
Wo
,
K1
);
const
auto
conv_strides_dev
=
make_tuple
(
conv_stride_h
,
conv_stride_w
);
const
auto
conv_dilations_dev
=
make_tuple
(
conv_dilation_h
,
conv_dilation_w
);
const
auto
in_left_pads_dev
=
make_tuple
(
in_left_pad_h
,
in_left_pad_w
);
const
auto
in_right_pads_dev
=
make_tuple
(
in_right_pad_h
,
in_right_pad_w
);
return
make_tuple
(
in_lengths_dev
,
wei_lengths_dev
,
add_lengths_dev
,
out_lengths_dev
,
conv_strides_dev
,
conv_dilations_dev
,
in_left_pads_dev
,
in_right_pads_dev
);
};
#if USE_CONV_FWD_V5R1_NCHWC
if
(
algo
==
ConvForwardAlgo
::
V5R1NCHWC
)
{
const
auto
tmp
=
f_make_for_device_nchwc
();
device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
<
in_data_t
,
acc_data_t
,
out_data_t
,
activ_type
>
(
tmp
[
I0
],
// in_lengths_dev
tmp
[
I1
],
// wei_lengths_dev
tmp
[
I2
],
// add_lengths_dev
tmp
[
I3
],
// out_lengths_dev
tmp
[
I4
],
// conv_strides_dev
tmp
[
I5
],
// conv_dilations_dev
tmp
[
I6
],
// in_left_pads_dev
tmp
[
I7
],
// in_right_pads_dev
in
,
wei
,
add
,
add_device
,
out_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_direct_convolution_add_nchwc
(
in
,
wei
,
add
,
add_host
,
out_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
activ_type
);
check_error
(
out_host
,
out_device
);
check_error
(
add_host
,
add_device
);
if
(
do_log
)
{
// LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") <<
// std::endl;
// LogRangeAsType<float>(std::cout << "add_device: ", add_device.mData, ",") <<
// std::endl;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"add_host: "
,
add_host
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
host/driver_offline/src/conv_fwd_driver_offline_nchwc.cpp
View file @
5b1a9994
...
...
@@ -45,6 +45,8 @@ int main(int argc, char* argv[])
exit
(
1
);
}
constexpr
index_t
activ_type
=
0
;
const
ConvForwardAlgo
algo
=
static_cast
<
ConvForwardAlgo
>
(
std
::
stoi
(
argv
[
1
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
2
]);
const
int
init_method
=
std
::
stoi
(
argv
[
3
]);
...
...
@@ -90,7 +92,9 @@ int main(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
4
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
5
]);
#if 1
constexpr
index_t
activ_type
=
0
;
#if 0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{};
...
...
@@ -100,7 +104,7 @@ int main(int argc, char* argv[])
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif
0
#elif
1
constexpr
auto
N
=
Number
<
1
>
{};
constexpr
auto
Hi
=
Number
<
540
>
{};
constexpr
auto
Wi
=
Number
<
960
>
{};
...
...
@@ -250,8 +254,6 @@ int main(int argc, char* argv[])
in_right_pads_dev
);
};
constexpr
index_t
activ_type
=
0
;
#if USE_CONV_FWD_V5R1_NCHWC
if
(
algo
==
ConvForwardAlgo
::
V5R1NCHWC
)
{
...
...
host/host_tensor/include/host_conv.hpp
View file @
5b1a9994
...
...
@@ -156,6 +156,70 @@ void host_direct_convolution_nchwc(const Tensor<TIn>& in,
out
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
TIn
,
typename
TWei
,
typename
TOut
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
host_direct_convolution_add_nchwc
(
const
Tensor
<
TIn
>&
in
,
const
Tensor
<
TWei
>&
wei
,
const
Tensor
<
TOut
>&
add
,
Tensor
<
TOut
>&
add_host
,
Tensor
<
TOut
>&
out_host
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
,
const
ck
::
index_t
activ_type
=
0
)
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_nchw
=
[
&
](
auto
n
,
auto
k0
,
auto
ho
,
auto
wo
,
auto
k1
)
{
double
v
=
0
;
for
(
int
c0
=
0
;
c0
<
wei
.
mDesc
.
GetLengths
()[
1
];
++
c0
)
{
for
(
int
c1
=
0
;
c1
<
wei
.
mDesc
.
GetLengths
()[
4
];
++
c1
)
{
for
(
int
y
=
0
;
y
<
wei
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
int
hi
=
ho
*
conv_strides
[
I0
]
+
y
*
conv_dilations
[
I0
]
-
in_left_pads
[
I0
];
for
(
int
x
=
0
;
x
<
wei
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
int
wi
=
wo
*
conv_strides
[
I1
]
+
x
*
conv_dilations
[
I1
]
-
in_left_pads
[
I1
];
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
{
v
+=
static_cast
<
const
double
>
(
in
(
n
,
c0
,
hi
,
wi
,
c1
))
*
static_cast
<
const
double
>
(
wei
(
k0
*
out_host
.
mDesc
.
GetLengths
()[
4
]
+
k1
,
c0
,
y
,
x
,
c1
));
}
}
}
}
}
v
=
activ
(
v
,
activ_type
);
out_host
(
n
,
k0
,
ho
,
wo
,
k1
)
=
v
;
add_host
(
n
,
k0
,
ho
,
wo
,
k1
)
=
v
+
add
(
n
,
k0
,
ho
,
wo
,
k1
);
add_host
(
n
,
k0
,
ho
,
wo
+
1
,
k1
)
=
v
+
add
(
n
,
k0
,
ho
,
wo
+
1
,
k1
);
add_host
(
n
,
k0
,
ho
+
1
,
wo
,
k1
)
=
v
+
add
(
n
,
k0
,
ho
+
1
,
wo
,
k1
);
add_host
(
n
,
k0
,
ho
+
1
,
wo
+
1
,
k1
)
=
v
+
add
(
n
,
k0
,
ho
+
1
,
wo
+
1
,
k1
);
};
make_ParallelTensorFunctor
(
f_nchw
,
out_host
.
mDesc
.
GetLengths
()[
0
],
out_host
.
mDesc
.
GetLengths
()[
1
],
out_host
.
mDesc
.
GetLengths
()[
2
],
out_host
.
mDesc
.
GetLengths
()[
3
],
out_host
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
TIn
,
typename
TWei
,
typename
TOut
,
typename
InLeftPads
,
typename
InRightPads
>
void
host_winograd_3x3_convolution
(
const
Tensor
<
TIn
>&
in_nchw
,
const
Tensor
<
TWei
>&
wei_kcyx
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment