"...composable_kernel.git" did not exist on "a72a5762fbe6ba834d16f35f2716dda39c4a95bf"
Commit 756b0ca1 authored by Chao Liu's avatar Chao Liu
Browse files

delete obselete files

parent e17c0d80
#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
template <typename GridwiseOp, typename... Xs>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
run_gridwise_operation(Xs... xs)
{
GridwiseOp{}.Run(xs...);
}
#endif
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
constexpr index_t KPerThread = CK_PARAM_KPerThread;
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
using ABlockTransferThreadSliceLengths_K_M0_M1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
using ABlockTransferThreadClusterLengths_K_M0_M1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
CK_PARAM_ABlockTransferDstScalarPerVector_M1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K_N0_N1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
using BBlockTransferThreadClusterLengths_K_N0_N1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
CK_PARAM_BBlockTransferDstScalarPerVector_N1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
int n,
int c,
int hi,
int wi,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k_m0_m1_grid_desc,
void* p_b_k_n0_n1_grid_desc,
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW));
const auto a_k_m_grid_desc = descs[I0];
const auto b_k_n_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AKMGridDesc = decltype(a_k_m_grid_desc);
using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
};
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1));
constexpr auto a_k_m_grid_desc = descs[I0];
constexpr auto b_k_n_grid_desc = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AKMGridDesc = decltype(a_k_m_grid_desc);
using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
constexpr auto a_k_m0_m1_grid_desc_tmp =
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
constexpr auto b_k_n0_n1_grid_desc_tmp =
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k_m0_m1_grid_desc =
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
const auto b_k_n0_n1_grid_desc =
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;
using ABlockTransferThreadSliceLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
int n,
int c,
int hi,
int wi,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW),
Number<K1>{});
const auto a_k0_m_k1_grid_desc = descs[I0];
const auto b_k0_n_k1_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
a_k0_m_k1_grid_desc;
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
Number<K1>{});
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor);
};
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;
using ABlockTransferThreadSliceLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
int n,
int hi,
int wi,
int c,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c));
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW),
Number<K1>{});
const auto a_k0_m_k1_grid_desc = descs[I0];
const auto b_k0_n_k1_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto cblockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
a_k0_m_k1_grid_desc;
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
constexpr auto wei_k_y_x_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256));
constexpr auto out_n_ho_wo_k_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
Number<K1>{});
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto cblockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
cblockid_to_m0_n0_block_cluster_adaptor);
};
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr auto GN0 = Number<CK_PARAM_GN0>{};
constexpr auto GK1 = Number<CK_PARAM_GK1>{};
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>;
using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>;
using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>;
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
extern "C" __global__ void
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
int C_,
int Hi_,
int Wi_,
int K_,
int Y_,
int X_,
int ConvStrideH_,
int ConvStrideW_,
int ConvDilationH_,
int ConvDilationW_,
int InLeftPadH_,
int InLeftPadW_,
int InRightPadH_,
int InRightPadW_,
void* p_desc_tuple)
{
index_t N = static_cast<index_t>(N_);
index_t C = static_cast<index_t>(C_);
index_t Hi = static_cast<index_t>(Hi_);
index_t Wi = static_cast<index_t>(Wi_);
index_t K = static_cast<index_t>(K_);
index_t Y = static_cast<index_t>(Y_);
index_t X = static_cast<index_t>(X_);
index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
index_t InRightPadH = static_cast<index_t>(InRightPadH_);
index_t InRightPadW = static_cast<index_t>(InRightPadW_);
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t Ho =
(Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
const index_t Wo =
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(ConvStrideH, ConvStrideW),
make_tuple(ConvDilationH, ConvDilationW),
make_tuple(InLeftPadH, InLeftPadW),
make_tuple(InRightPadH, InRightPadW),
GN0,
GK1);
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridStepHacks = decltype(make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction =
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
auto desc_tuple =
make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
a_grid_desc_gk0_gm0_gm1_gk1),
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
b_grid_desc_gk0_gn0_gn1_gk1),
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1),
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1));
*static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_desc_tuple)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
GN0,
GK1);
constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridStepHacks = decltype(make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction =
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
a_grid_desc_gk0_gm0_gm1_gk1));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
b_grid_desc_gk0_gn0_gn1_gk1));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1));
using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
BGridDesc_GK0_GN0_GN10_GN11_GK1{},
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
CGridBlockCluster_BlockId_To_GM10_GN10{}));
const auto desc_tuple =
*reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2];
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};
#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP
#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP
#include <numeric>
#include <sstream>
namespace ck {
namespace driver {
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
{
auto GetCompileParameterString() const
{
auto param = std::stringstream();
// clang-format off
param <<
" -DCK_PARAM_ABDataTypeEnum=" <<
ABDataTypeEnum <<
" -DCK_PARAM_AccDataTypeEnum=" <<
AccDataTypeEnum <<
" -DCK_PARAM_CDataTypeEnum=" <<
CDataTypeEnum <<
" -DCK_PARAM_BlockSize=" <<
BlockSize <<
" -DCK_PARAM_GN0=" <<
GN0 <<
" -DCK_PARAM_GK1=" <<
GK1 <<
" -DCK_PARAM_GM1PerBlockGM11="
<< GM1PerBlockGM11 <<
" -DCK_PARAM_GN1PerBlockGN11=" <<
GN1PerBlockGN11 <<
" -DCK_PARAM_GK0PerBlock=" <<
GK0PerBlock <<
" -DCK_PARAM_BM1PerThreadBM11=" <<
BM1PerThreadBM11 <<
" -DCK_PARAM_BN1PerThreadBN11=" <<
BN1PerThreadBN11 <<
" -DCK_PARAM_BK0PerThread=" <<
BK0PerThread <<
" -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" <<
BM10BN10ThreadClusterBM10Xs[0] << "," <<
BM10BN10ThreadClusterBM10Xs[1] <<
" -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" <<
BM10BN10ThreadClusterBN10Xs[0] << "," <<
BM10BN10ThreadClusterBN10Xs[1] <<
" -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" <<
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4] <<
" -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" <<
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4] <<
" -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" <<
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] <<
" -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" <<
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] <<
" -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" <<
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4] <<
" -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" <<
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4] <<
" -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" <<
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] <<
" -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" <<
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] <<
" -DCK_PARAM_CThreadTransferDstScalarPerVector=" <<
CThreadTransferDstScalarPerVector <<
" -DCK_PARAM_HasMainKBlockLoop=" <<
static_cast<int>(HasMainKBlockLoop) <<
" -DCK_PARAM_HasDoubleTailKBlockLoop=" <<
static_cast<int>(HasDoubleTailKBlockLoop);
// clang-format on
return param.str();
}
ck::DataTypeEnum_t ABDataTypeEnum = ck::DataTypeEnum_t::Unknown;
ck::DataTypeEnum_t AccDataTypeEnum = ck::DataTypeEnum_t::Unknown;
ck::DataTypeEnum_t CDataTypeEnum = ck::DataTypeEnum_t::Unknown;
int BlockSize = -1;
int GN0 = -1;
int GK1 = -1;
int GM1PerBlockGM11 = -1;
int GN1PerBlockGN11 = -1;
int GK0PerBlock = -1;
int BM1PerThreadBM11 = -1;
int BN1PerThreadBN11 = -1;
int BK0PerThread = -1;
std::array<int, 2> BM10BN10ThreadClusterBM10Xs = {-1, -1};
std::array<int, 2> BM10BN10ThreadClusterBN10Xs = {-1, -1};
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
-1, -1, -1, -1, -1};
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
-1, -1, -1, -1, -1};
int CThreadTransferDstScalarPerVector = -1;
bool HasMainKBlockLoop = false;
bool HasDoubleTailKBlockLoop = false;
};
struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw
{
ck::DataTypeEnum_t ABDataTypeEnum;
ck::DataTypeEnum_t CDataTypeEnum;
int BlockSize;
int GN0;
int GK1;
int GM1PerBlockGM11;
int GN1PerBlockGN11;
int GK0PerBlock;
int BM1PerThreadBM11;
int BN1PerThreadBN11;
int BK0PerThread;
std::array<int, 2> BM10BN10ThreadClusterBM10Xs;
std::array<int, 2> BM10BN10ThreadClusterBN10Xs;
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
};
inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()
{
constexpr auto f32 = ck::DataTypeEnum_t::Float;
constexpr auto f16 = ck::DataTypeEnum_t::Half;
constexpr auto i8 = ck::DataTypeEnum_t::Int8;
return std::vector<TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw>{
// clang-format off
// fp32
{f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
// fp16
{f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
// i8
{ i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{ i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{ i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{ i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
{ i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}
// clang-format on
};
}
// TODO make this common interface and write specs for it
struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
{
static auto
CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc,
const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable)
{
const int C = conv_problem_desc.C;
const int Y = conv_problem_desc.Y;
const int X = conv_problem_desc.X;
const int Ho = conv_problem_desc.Ho;
const int Wo = conv_problem_desc.Wo;
if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum &&
conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum &&
conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum))
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum;
const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum;
DataTypeEnum_t AccDataTypeEnum;
if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half)
{
AccDataTypeEnum = DataTypeEnum_t::Float;
}
else if(ABDataTypeEnum == DataTypeEnum_t::Int8)
{
AccDataTypeEnum = DataTypeEnum_t::Int32;
}
else
{
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
}
const int BlockSize = tunable.BlockSize;
const int GN0 = tunable.GN0;
const int GK1 = tunable.GK1;
const int GM11 = tunable.GM1PerBlockGM11;
const int GN11 = tunable.GN1PerBlockGN11;
const int GK0PerBlock = tunable.GK0PerBlock;
const int BM11 = tunable.BM1PerThreadBM11;
const int BN11 = tunable.BN1PerThreadBN11;
const int BK0PerThread = tunable.BK0PerThread;
const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs;
const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs;
const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
// C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim
const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo);
const int C0 = GK1;
if(!(C % C0 == 0))
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
const int C1 = C / C0;
const int GK0 = C1 * Y * X;
if(!(GK0 % GK0PerBlock == 0))
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1);
const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0);
return std::make_tuple(
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{
ABDataTypeEnum,
AccDataTypeEnum,
CDataTypeEnum,
BlockSize,
GN0,
GK1,
GM11,
GN11,
GK0PerBlock,
BM11,
BN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
CThreadTransferDstScalarPerVector,
HasMainKBlockLoop,
HasDoubleTailKBlockLoop},
true);
}
static auto GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc)
{
for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw())
{
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{};
bool found = false;
std::tie(compile_param, found) =
CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable);
if(found && IsValidCompileParameter(conv_problem_desc, compile_param))
return std::make_tuple(compile_param, true);
}
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
}
static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc)
{
bool found = false;
std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc);
return found;
}
static bool
IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc,
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
{
const int N = conv_problem_desc.N;
const int K = conv_problem_desc.K;
const int C = conv_problem_desc.C;
const int Y = conv_problem_desc.Y;
const int X = conv_problem_desc.X;
const int Ho = conv_problem_desc.Ho;
const int Wo = conv_problem_desc.Wo;
const int GK1 = compile_param.GK1;
const int GN0 = compile_param.GN0;
const int GM11 = compile_param.GM1PerBlockGM11;
const int GN11 = compile_param.GN1PerBlockGN11;
const int BM11 = compile_param.BM1PerThreadBM11;
const int BN11 = compile_param.BN1PerThreadBN11;
const int C0 = GK1;
const int N0 = GN0;
if(!(C % C0 == 0))
return false;
const int C1 = C / C0;
if(!(N % N0 == 0))
return false;
const int N1 = N / N0;
const int GM0 = 1;
const int GM1 = K;
const int GN1 = N1 * Ho * Wo;
const int GK0 = C1 * Y * X;
// check data type
{
if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum &&
conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum))
return false;
if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float ||
compile_param.ABDataTypeEnum == DataTypeEnum_t::Half)
{
if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float))
return false;
}
else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8)
{
if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32))
return false;
}
}
// check gridwise contraction
{
if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0))
return false;
const bool has_main_k_block_loop =
((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1);
const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0);
if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop &&
has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop))
return false;
}
// check A blockwise copy
{
const auto block_slice_lengths =
std::array<int, 5>{compile_param.GK0PerBlock, GM0, 1, GM11, GK1};
const auto& cluster_lengths =
compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
const auto& thread_slice_lengths =
compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
const auto& src_vector_lengths =
compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
const auto& dst_vector_lengths =
compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
// check number of working thread
const int num_work_thread = std::accumulate(
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies<int>{});
if(!(compile_param.BlockSize >= num_work_thread))
return false;
// check block slice lengths vs thread slice lengths vs cluster lengths
for(int i = 0; i < 5; ++i)
{
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i]))
return false;
}
// check thread slice lengths vs vector lengths
for(int i = 0; i < 5; ++i)
{
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0))
return false;
if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0))
return false;
}
// check Src vectorization, GK0 is global mem vector dim
if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 &&
src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1))
return false;
// check Dst vectorization, {GM11, GK1} are LDS vector dims
if(dst_vector_lengths[4] == GK1)
{ // vectorize on {GM11, GK1}
if(!(GM11 % dst_vector_lengths[3] == 0))
return false;
}
else
{ // vectorize on {GK1} only
if(!(GK1 % dst_vector_lengths[4] == 0))
return false;
if(!(dst_vector_lengths[3] == 1))
return false;
}
}
// check B blockwise copy
{
const auto block_slice_lengths =
std::array<int, 5>{compile_param.GK0PerBlock, GN0, 1, GN11, GK1};
const auto& cluster_lengths =
compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
const auto& thread_slice_lengths =
compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
const auto& src_vector_lengths =
compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
const auto& dst_vector_lengths =
compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
// check number of working thread
const int num_work_thread = std::accumulate(
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies<int>{});
if(!(compile_param.BlockSize >= num_work_thread))
return false;
// check block slice lengths vs thread slice lengths vs cluster lengths
for(int i = 0; i < 5; ++i)
{
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i]))
return false;
}
// check thread slice lengths vs vector lengths
for(int i = 0; i < 5; ++i)
{
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 &&
thread_slice_lengths[i] % dst_vector_lengths[i] == 0))
return false;
}
// check Src vectorization: {GN11} is global mem vector dim
if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 &&
src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1))
return false;
// check Src tensor layout related vectorization
if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 &&
conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 &&
conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 &&
conv_problem_desc.InRightPadW == 0)
{
if(!((Ho * Wo) % src_vector_lengths[3] == 0))
return false;
}
else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 &&
conv_problem_desc.InRightPadW == 0)
{
if(!(Wo % src_vector_lengths[3] == 0))
return false;
}
else
{
if(!(src_vector_lengths[3] == 1))
return false;
}
// check Dst vectorization: {GN11, GK1} are LDS vector dims
if(dst_vector_lengths[4] == GK1)
{ // vectorize on {GN11, GK1}
if(!(GN11 % dst_vector_lengths[3] == 0))
return false;
}
else
{ // vectorize on {GK1} only
if(!(dst_vector_lengths[3] == 1))
return false;
if(!(GK1 % dst_vector_lengths[4] == 0))
return false;
}
}
// check blockwise GEMM
{
const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(),
compile_param.BM10BN10ThreadClusterBM10Xs.end(),
1,
std::multiplies<int>{});
const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(),
compile_param.BM10BN10ThreadClusterBN10Xs.end(),
1,
std::multiplies<int>{});
if(!(compile_param.BlockSize == BM10 * BN10))
return false;
const int BM = GM0 * GM11;
const int BN = GN0 * GN11;
const int BM1 = BM10 * BM11;
const int BN1 = BN10 * BN11;
if(!(BM % BM1 == 0 && BN % BN1 == 0))
return false;
const int BM0 = BM / BM1;
const int BN0 = BN / BN1;
// blockwise GEMM currently only support BM0 == 2 && BN0 == 2
if(!(BM0 == 2 && BN0 == 2))
return false;
if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0))
return false;
}
// check C threadwise copy
{
// {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim
const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector;
// check slice length vs Dst vector length:
if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0))
return false;
// check Dst memory layout related vectorization:
if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0))
return false;
}
return true;
};
static int GetBlockSize(const ConvolutionProblemDescriptor&,
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
{
return compile_param.BlockSize;
}
static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc,
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
{
const int N = conv_problem_desc.N;
const int K = conv_problem_desc.K;
const int Ho = conv_problem_desc.Ho;
const int Wo = conv_problem_desc.Wo;
const int N0 = compile_param.GN0;
const int N1 = N / N0;
const int GM1 = K;
const int GN1 = N1 * Ho * Wo;
const int GM11 = compile_param.GM1PerBlockGM11;
const int GN11 = compile_param.GN1PerBlockGN11;
const int GM10 = GM1 / GM11;
const int GN10 = GN1 / GN11;
return GM10 * GN10;
}
static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&,
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&)
{
// workspace is used for save transformed tensor descritpors created by prepare kernel
return 4096L;
}
static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; }
static auto GetTunableList()
{
return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw();
}
};
} // namespace driver
} // namespace ck
#endif
#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP
#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP
struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw
{
int BlockSize;
int MPerBlock;
int NPerBlock;
int KPerBlock;
int M1PerThread;
int N1PerThread;
int KPerThread;
int M1N1ThreadClusterM10;
int M1N1ThreadClusterN10;
int M1N1ThreadClusterM11;
int M1N1ThreadClusterN11;
std::array<int, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
std::array<int, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<int, 3> ABlockTransferSrcAccessOrder;
int ABlockTransferSrcVectorDim;
int ABlockTransferSrcScalarPerVector;
int ABlockTransferDstScalarPerVector_M1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<int, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
std::array<int, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<int, 3> BBlockTransferSrcAccessOrder;
int BBlockTransferSrcVectorDim;
int BBlockTransferSrcScalarPerVector;
int BBlockTransferDstScalarPerVector_N1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<int, 6> CThreadTransferSrcDstAccessOrder;
int CThreadTransferSrcDstVectorDim;
int CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw
default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = {
256, 128, 128, 8, 4, 4, 1,
8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0},
{2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128},
{0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
5, 1};
#endif
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
int BlockSize;
int MPerBlock;
int NPerBlock;
int KPerBlock;
int MPerXDL;
int NPerXDL;
int K1;
int MRepeat;
int NRepeat;
std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<int, 3> ABlockTransferSrcAccessOrder;
int ABlockTransferSrcVectorDim;
int ABlockTransferSrcScalarPerVector;
int ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<int, 3> BBlockTransferSrcAccessOrder;
int BBlockTransferSrcVectorDim;
int BBlockTransferSrcScalarPerVector;
int BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<int, 8> CThreadTransferSrcDstAccessOrder;
int CThreadTransferSrcDstVectorDim;
int CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerXDL,
32, // NPerXDL,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
#endif
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
int BlockSize;
int MPerBlock;
int NPerBlock;
int KPerBlock;
int MPerWave;
int NPerWave;
int K1;
int MRepeat;
int NRepeat;
std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<int, 3> ABlockTransferSrcAccessOrder;
int ABlockTransferSrcVectorDim;
int ABlockTransferSrcScalarPerVector;
int ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<int, 3> BBlockTransferSrcAccessOrder;
int BBlockTransferSrcVectorDim;
int BBlockTransferSrcScalarPerVector;
int BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<int, 8> CThreadTransferSrcDstAccessOrder;
int CThreadTransferSrcDstVectorDim;
int CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
#endif
#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR
#define CONVOLUTION_PROBLEM_DESCRIPTOR
namespace ck {
namespace driver {
struct ConvolutionProblemDescriptor
{
ConvolutionProblemDescriptor() = default;
ConvolutionProblemDescriptor(int N_,
int K_,
int C_,
int Y_,
int X_,
int Hi_,
int Wi_,
int Ho_,
int Wo_,
int ConvStrideH_,
int ConvStrideW_,
int ConvDilationH_,
int ConvDilationW_,
int InLeftPadH_,
int InLeftPadW_,
int InRightPadH_,
int InRightPadW_,
ck::DataTypeEnum_t InDataTypeEnum_,
ck::DataTypeEnum_t WeiDataTypeEnum_,
ck::DataTypeEnum_t OutDataTypeEnum_)
: N{N_},
K{K_},
C{C_},
Y{Y_},
X{X_},
Hi{Hi_},
Wi{Wi_},
Ho{Ho_},
Wo{Wo_},
ConvStrideH{ConvStrideH_},
ConvStrideW{ConvStrideW_},
ConvDilationH{ConvDilationH_},
ConvDilationW{ConvDilationW_},
InLeftPadH{InLeftPadH_},
InLeftPadW{InLeftPadW_},
InRightPadH{InRightPadH_},
InRightPadW{InRightPadW_},
InDataTypeEnum{InDataTypeEnum_},
WeiDataTypeEnum{WeiDataTypeEnum_},
OutDataTypeEnum{OutDataTypeEnum_}
{
}
int N;
int K;
int C;
int Y;
int X;
int Hi;
int Wi;
int Ho;
int Wo;
int ConvStrideH;
int ConvStrideW;
int ConvDilationH;
int ConvDilationW;
int InLeftPadH;
int InLeftPadW;
int InRightPadH;
int InRightPadW;
ck::DataTypeEnum_t InDataTypeEnum;
ck::DataTypeEnum_t WeiDataTypeEnum;
ck::DataTypeEnum_t OutDataTypeEnum;
std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; }
};
} // namespace driver
} // namespace ck
#endif
#ifndef CK_SOLVER_COMMON_HPP
#define CK_SOLVER_COMMON_HPP
namespace ck {
namespace driver {
// greatest common divisor, aka highest common factor
inline int gcd(int x, int y)
{
if(x < 0)
{
return gcd(-x, y);
}
else if(y < 0)
{
return gcd(x, -y);
}
else if(x == y || x == 0)
{
return y;
}
else if(y == 0)
{
return x;
}
else if(x > y)
{
return gcd(x % y, y);
}
else
{
return gcd(x, y % x);
}
}
template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
auto gcd(X x, Ys... ys)
{
return gcd(x, gcd(ys...));
}
} // namespace driver
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment