Commit 669df2d3 authored by ltqin's avatar ltqin
Browse files

start device

parent cfc80c01
......@@ -54,14 +54,14 @@ template <
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
using DeviceOp = DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = InDataType;
using BDataType = WeiDataType;
using CDataType = OutDataType;
using BDataType = OutDataType;
using CDataType = WeiDataType;
// TODO make A/B datatype different
using ABDataType = InDataType;
......@@ -432,10 +432,11 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
OutElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
K0PerBlock * K1,
K1, // AK1
K1, // BK1
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
......@@ -491,8 +492,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
wei_element_op_{out_element_op},
out_element_op_{wei_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
......@@ -525,7 +526,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
}
}
......@@ -538,7 +540,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
InElementwiseOperation in_element_op_;
......@@ -628,7 +630,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(
......@@ -662,7 +664,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(
......@@ -838,7 +840,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -35,8 +35,8 @@ static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
// clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device::
DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
using DeviceConvWrWInstance = ck::tensor_operation::device::
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType
WeiDataType, // WeiDataType
OutDataType, // OutDataType
......@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// do GEMM
auto conv = DeviceConvFwdInstance{};
auto conv = DeviceConvWrWInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment