Commit df22ba01 authored by ltqin's avatar ltqin
Browse files

start to use atomic add

parent 162359b6
...@@ -11,48 +11,47 @@ ...@@ -11,48 +11,47 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v3r1.hpp" #include "gridwise_gemm_xdlops_v2r4r2.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template < template <typename InDataType,
typename InDataType, typename WeiDataType,
typename WeiDataType, typename OutDataType,
typename OutDataType, typename AccDataType,
typename AccDataType, typename InElementwiseOperation,
typename InElementwiseOperation, typename WeiElementwiseOperation,
typename WeiElementwiseOperation, typename OutElementwiseOperation,
typename OutElementwiseOperation, ck::index_t BlockSize,
ck::index_t BlockSize, ck::index_t MPerBlock,
ck::index_t MPerBlock, ck::index_t NPerBlock,
ck::index_t NPerBlock, ck::index_t K0PerBlock,
ck::index_t K0PerBlock, ck::index_t K1,
ck::index_t K1, ck::index_t MPerXdl,
ck::index_t MPerXdl, ck::index_t NPerXdl,
ck::index_t NPerXdl, ck::index_t MXdlPerWave,
ck::index_t MXdlPerWave, ck::index_t NXdlPerWave,
ck::index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcAccessOrder, ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferDstScalarPerVector_K1,
ck::index_t ABlockTransferDstScalarPerVector_K1, bool ABlockLdsAddExtraM,
bool ABlockLdsAddExtraM, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcAccessOrder, ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferDstScalarPerVector_K1,
ck::index_t BBlockTransferDstScalarPerVector_K1, bool BBlockLdsAddExtraN,
bool BBlockLdsAddExtraN, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> : public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{ {
...@@ -62,6 +61,10 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -62,6 +61,10 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
using BDataType = InDataType; using BDataType = InDataType;
using CDataType = WeiDataType; using CDataType = WeiDataType;
using AElementwiseOperation = OutElementwiseOperation;
using BElementwiseOperation = InElementwiseOperation;
using CElementwiseOperation = WeiElementwiseOperation;
// TODO make A/B datatype different // TODO make A/B datatype different
using ABDataType = InDataType; using ABDataType = InDataType;
...@@ -87,7 +90,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -87,7 +90,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_pads,
ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -115,7 +119,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -115,7 +119,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const index_t GemmKTotal = N * Ho * Wo; const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * X * Y; const index_t GemmN = C * X * Y;
const index_t GemmK0 = GemmKTotal / GemmK1Number; const index_t GemmK0 = GemmKTotal / GemmK1Number / batch_k;
const auto out_gemmktotal_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
...@@ -123,12 +127,12 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -123,12 +127,12 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
// A: output tensor // A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmktotal_gemmm_grid_desc, out_gemmktotal_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(batch_k, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor // B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
...@@ -156,49 +160,48 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -156,49 +160,48 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{})); make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc, in_gemmktotal_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), make_tuple(make_unmerge_transform(make_tuple(batch_k, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmm_gemmn_grid_desc);
} }
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1));
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize, BlockSize,
ABDataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::AtomicAdd,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, BGridDesc_K0_N_K1,
CGridDesc_M_N, CGridDesc_M_N,
InElementwiseOperation, AElementwiseOperation,
WeiElementwiseOperation, BElementwiseOperation,
OutElementwiseOperation, CElementwiseOperation,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock * K1, K0PerBlock,
K1, // AK1
K1, // BK1
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
K1,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
...@@ -219,10 +222,15 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -219,10 +222,15 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferScalarPerVector_NWaveNPerXdl>; CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
// Argument // Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
...@@ -242,14 +250,15 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -242,14 +250,15 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
ck::index_t N01, ck::index_t N01,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) OutElementwiseOperation out_element_op,
ck::index_t split_k)
: p_a_grid_{p_out_grid}, : p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid}, p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid}, p_c_grid_{p_wei_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
...@@ -262,7 +271,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -262,7 +271,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads},
k_batch_{split_k}
{ {
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
...@@ -274,35 +284,36 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -274,35 +284,36 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
k_batch_);
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1]; a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
c_grid_desc_m_n_ = descs[I2]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
{ {
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm:: GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
} }
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm:: CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl Block2CTileMap block_2_ctile_map_;
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; ;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation a_element_op_; InElementwiseOperation a_element_op_;
...@@ -316,6 +327,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -316,6 +327,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
std::vector<index_t> conv_filter_strides_; std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_; std::vector<index_t> input_left_pads_;
std::vector<index_t> input_right_pads_; std::vector<index_t> input_right_pads_;
index_t k_batch_;
}; };
// Invoker // Invoker
...@@ -326,42 +338,22 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -326,42 +338,22 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout
<< "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_"
"nwavenperxdl_{ "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I0)
<< ", "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I1)
<< ", "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I2)
<< ", "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I3)
<< ", "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I4)
<< ", "
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
.GetLength(I5)
<< "}" << std::endl;
} }
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.M01_, arg.M01_,
arg.N01_)) arg.N01_))
...@@ -369,10 +361,10 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -369,10 +361,10 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
} }
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
...@@ -380,71 +372,67 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -380,71 +372,67 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
const auto kernel = kernel_gemm_xdlops_v3r1< const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel( ave_time =
kernel, launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v3r1< const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel( ave_time =
kernel, launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -465,7 +453,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -465,7 +453,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// vector load A/B matrix from global memory // vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
{ {
...@@ -479,8 +467,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -479,8 +467,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.M01_, arg.M01_,
arg.N01_); arg.N01_);
...@@ -506,7 +494,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -506,7 +494,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) OutElementwiseOperation out_element_op,
ck::index_t split_k)
{ {
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
...@@ -525,7 +514,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -525,7 +514,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
1, 1,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op}; out_element_op,
split_k};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -565,7 +555,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -565,7 +555,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
1, 1,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op,
1);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device:: ...@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
32, // NPerXdl 32, // NPerXdl
2, // MXdlPerWave 2, // MXdlPerWave
2, // NXdlPerWave 2, // NXdlPerWave
S<4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1 2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM true, // ABlockLdsAddExtraM
S<4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1 2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl S<1, 16, 1, 4>, //
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
...@@ -82,7 +82,7 @@ int main(int argc, char* argv[]) ...@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// Conv shape // Conv shape
ck::index_t N = 128; ck::index_t N = 128;
ck::index_t K = 256; ck::index_t K = 256;
ck::index_t C = 192; ck::index_t C = 128;
ck::index_t Y = 3; ck::index_t Y = 3;
ck::index_t X = 3; ck::index_t X = 3;
ck::index_t Hi = 71; ck::index_t Hi = 71;
...@@ -95,6 +95,7 @@ int main(int argc, char* argv[]) ...@@ -95,6 +95,7 @@ int main(int argc, char* argv[])
ck::index_t in_left_pad_w = 1; ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1; ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_right_pad_w = 1;
ck::index_t split_k = 1;
if(argc == 4) if(argc == 4)
{ {
...@@ -102,7 +103,7 @@ int main(int argc, char* argv[]) ...@@ -102,7 +103,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); nrepeat = std::stoi(argv[3]);
} }
else if(argc == 19) else if(argc == 20)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -123,6 +124,7 @@ int main(int argc, char* argv[]) ...@@ -123,6 +124,7 @@ int main(int argc, char* argv[])
in_left_pad_w = std::stoi(argv[16]); in_left_pad_w = std::stoi(argv[16]);
in_right_pad_h = std::stoi(argv[17]); in_right_pad_h = std::stoi(argv[17]);
in_right_pad_w = std::stoi(argv[18]); in_right_pad_w = std::stoi(argv[18]);
split_k = std::stoi(argv[19]);
} }
else else
{ {
...@@ -185,12 +187,13 @@ int main(int argc, char* argv[]) ...@@ -185,12 +187,13 @@ int main(int argc, char* argv[])
case 0: break; case 0: break;
case 1: case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
} }
wei_k_c_y_x_device_result.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{0});
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * DeviceMem wei_device_buf(sizeof(WeiDataType) *
...@@ -199,6 +202,9 @@ int main(int argc, char* argv[]) ...@@ -199,6 +202,9 @@ int main(int argc, char* argv[])
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data());
LogRangeAsType<float>(std::cout << "wei_device(before): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
// do GEMM // do GEMM
auto conv = DeviceConvWrWInstance{}; auto conv = DeviceConvWrWInstance{};
...@@ -218,7 +224,8 @@ int main(int argc, char* argv[]) ...@@ -218,7 +224,8 @@ int main(int argc, char* argv[])
input_right_pads, input_right_pads,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{},
split_k);
if(!conv.IsSupportedArgument(argument)) if(!conv.IsSupportedArgument(argument))
{ {
...@@ -262,6 +269,16 @@ int main(int argc, char* argv[]) ...@@ -262,6 +269,16 @@ int main(int argc, char* argv[])
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
if(1)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
}
check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment