// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" using InDataType = float; using WeiDataType = float; using AccDataType = float; using CShuffleDataType = float; using OutDataType = float; template using S = ck::Sequence; using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; template using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, ck::Tuple<>, OutLayout, InDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple<>, OutDataType, InElementOp, WeiElementOp, OutElementOp, ConvSpec, // ConvForwardSpecialization GemmSpec, // GemmSpecialization 1, // 256, // BlockSize 128, // MPerBlock 256, // NPerBlock 16, // KPerBlock 4, // AK1 4, // BK1 32, // MPerXdl 32, // NPerXdl 2, // MXdlPerWave 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 4, // ABlockTransferSrcScalarPerVector 4, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 4, // BBlockTransferSrcScalarPerVector 4, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockLdsExtraN 1, 1, S<1, 16, 1, 16>, 4>; #include "run_convnd_fwd_example.inc" int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }