Commit 669df2d3 authored by ltqin's avatar ltqin
Browse files

start device

parent cfc80c01
...@@ -54,14 +54,14 @@ template < ...@@ -54,14 +54,14 @@ template <
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{ {
using DeviceOp = DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = InDataType; using ADataType = InDataType;
using BDataType = WeiDataType; using BDataType = OutDataType;
using CDataType = OutDataType; using CDataType = WeiDataType;
// TODO make A/B datatype different // TODO make A/B datatype different
using ABDataType = InDataType; using ABDataType = InDataType;
...@@ -432,10 +432,11 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -432,10 +432,11 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
OutElementwiseOperation, OutElementwiseOperation,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock * K1,
K1, // AK1
K1, // BK1
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
K1,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
...@@ -491,8 +492,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -491,8 +492,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
in_element_op_{in_element_op}, in_element_op_{in_element_op},
wei_element_op_{out_element_op}, wei_element_op_{wei_element_op},
out_element_op_{wei_element_op}, out_element_op_{out_element_op},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -525,7 +526,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -525,7 +526,8 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_); c_grid_desc_m_n_);
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
} }
} }
...@@ -538,7 +540,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -538,7 +540,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
typename GridwiseGemm:: typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -628,7 +630,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -628,7 +630,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -662,7 +664,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -662,7 +664,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::Block2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
...@@ -838,7 +840,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -838,7 +840,7 @@ struct DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -35,8 +35,8 @@ static constexpr auto ConvFwdDefault = ...@@ -35,8 +35,8 @@ static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
// clang-format off // clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device:: using DeviceConvWrWInstance = ck::tensor_operation::device::
DeviceConv2dWrwXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
OutDataType, // OutDataType OutDataType, // OutDataType
...@@ -205,7 +205,7 @@ int main(int argc, char* argv[]) ...@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
// do GEMM // do GEMM
auto conv = DeviceConvFwdInstance{}; auto conv = DeviceConvWrWInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment