Commit 3df07c27 authored by rocking's avatar rocking
Browse files

Use 1D global memory for count

parent 39dedce7
......@@ -33,7 +33,8 @@ template <typename GridwiseGemmWelford,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock,
typename MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
typename CountGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -57,8 +58,10 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock
count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map,
index_t NRaw)
{
......@@ -81,7 +84,8 @@ __global__ void
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map,
NRaw);
#else
......@@ -99,7 +103,8 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = mean_var_count_grid_desc_mblock_mperblock_nblock;
ignore = mean_var_grid_desc_mblock_mperblock_nblock;
ignore = count_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map;
ignore = NRaw;
#endif
......@@ -114,7 +119,8 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename LayernormMeanVarCountGridDesc_M_NBlock,
typename LayernormMeanVarGridDesc_M_NBlock,
typename LayernormCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation>
__global__ void
......@@ -131,7 +137,8 @@ __global__ void
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n,
const LayernormMeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
......@@ -148,7 +155,8 @@ __global__ void
p_h_grid,
e_grid_desc_m_n,
h_grid_desc_m_n,
mean_var_count_grid_desc_m_nblock,
mean_var_grid_desc_m_nblock,
count_grid_desc_m_nblock,
gamma_grid_desc_n,
beta_grid_desc_n,
numMeanVarCountBlockTileIteration_N,
......@@ -315,7 +323,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{});
}
template <typename LayOut, typename DoPads, index_t MPerTile, index_t NPerTile>
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
{
const auto grid_desc_m_n =
......@@ -323,6 +331,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
// Builds the (M, N) grid descriptor used for the welford count buffer.
//
// The naive descriptor is created with lengths (M, N) and the tuple (I0, I1) as
// its second argument, then padded to multiples of (MPerTile, NPerTile) as
// selected by DoPads.
// NOTE(review): the second argument is presumably the per-dimension strides,
// i.e. stride 0 along M so that every row aliases the same N-length storage
// (matching the "1D global memory for count" intent) - confirm against
// make_naive_tensor_descriptor.
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeCountDescriptor_M_N(index_t M, index_t N)
{
    const auto lengths_m_n = make_tuple(M, N);
    const auto strides_m_n = make_tuple(I0, I1);

    const auto unpadded_desc_m_n = make_naive_tensor_descriptor(lengths_m_n, strides_m_n);

    const auto tile_lengths_m_n = make_tuple(MPerTile, NPerTile);

    return PadTensorDescriptor(unpadded_desc_m_n, tile_lengths_m_n, DoPads{});
}
template <index_t XPerTile>
static auto MakeDescriptor_X(index_t X)
{
......@@ -335,15 +351,22 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout (different padding)
using GemmMeanVarCountGridDesc_M_NBlock = decltype(
MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using GemmMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using GemmCountGridDesc_M_NBlock =
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<HLayout,
Sequence<true, true>,
using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using LayernormCountGridDesc_M_NBlock =
decltype(MakeCountDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1));
......@@ -363,7 +386,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EHGridDesc_M_N,
GemmMeanVarCountGridDesc_M_NBlock,
GemmMeanVarGridDesc_M_NBlock,
GemmCountGridDesc_M_NBlock,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -408,7 +432,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
LayernormMeanVarCountGridDesc_M_NBlock,
LayernormMeanVarGridDesc_M_NBlock,
LayernormCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation,
BlockSize,
......@@ -456,8 +481,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
gemm_mean_var_count_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{},
gemm_mean_var_grid_desc_m_nblock_{},
gemm_count_grid_desc_m_nblock_{},
layernorm_mean_var_grid_desc_m_nblock_{},
layernorm_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{
......@@ -478,17 +505,26 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
gemm_mean_var_count_grid_desc_m_nblock_ = DeviceOp::
MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(
gemm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<HLayout,
Sequence<true, true>,
gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(
MRaw, gemm_nblock_);
layernorm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw,
gemm_nblock_);
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
......@@ -517,9 +553,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
mean_var_count_grid_desc_mblock_mperblock_nblock_ =
gemm_mean_var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_mean_var_count_grid_desc_m_nblock_);
gemm_mean_var_grid_desc_m_nblock_);
gemm_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_count_grid_desc_m_nblock_);
}
}
......@@ -551,8 +591,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_;
GemmMeanVarCountGridDesc_M_NBlock gemm_mean_var_count_grid_desc_m_nblock_;
LayernormMeanVarCountGridDesc_M_NBlock layernorm_mean_var_count_grid_desc_m_nblock_;
GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_;
GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_;
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_;
......@@ -564,8 +606,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
gemm_mean_var_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock
gemm_count_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
......@@ -628,8 +672,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>;
......@@ -643,7 +687,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType,
AccDataType,
EHGridDesc_M_N,
LayernormMeanVarCountGridDesc_M_NBlock,
LayernormMeanVarGridDesc_M_NBlock,
LayernormCountGridDesc_M_NBlock,
GammaBetaGridDesc_N,
HElementwiseOperation>;
......@@ -667,7 +712,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.gemm_mean_var_grid_desc_mblock_mperblock_nblock_,
arg.gemm_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_,
arg.NRaw_);
......@@ -694,7 +740,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_count_grid_desc_m_nblock_,
arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
......@@ -738,7 +785,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
workspace_size += gemm_welford_size * sizeof(VarDataType) + 64;
// workspace for welford intermediate count
workspace_size += gemm_welford_size * sizeof(int32_t) + 64;
workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64;
return (workspace_size);
};
......
......@@ -47,7 +47,8 @@ template <typename ABDataType,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -347,8 +348,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>;
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......@@ -361,27 +364,29 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
__device__ static void
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -401,16 +406,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_var_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
......@@ -880,7 +882,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_count_grid_desc_mblock_mperblock_nblock.GetLength(I2);
const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
// tail block
if(block_work_idx[I1] % nblock == nblock - 1)
......@@ -1038,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType,
MeanDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
......@@ -1046,7 +1048,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
false>{mean_var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
......@@ -1057,7 +1059,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType,
VarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
......@@ -1065,7 +1067,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
false>{mean_var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
......@@ -1076,7 +1078,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
......@@ -1084,32 +1086,30 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
false>{count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf);
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
});
} // shuffle C + Ds + welford + write out
......
......@@ -27,7 +27,8 @@ template <typename EDataType,
typename BetaDataType,
typename ComputeDataType,
typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock,
typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N,
typename HElementwiseOperation,
index_t BlockSize,
......@@ -95,7 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock& mean_var_count_grid_desc_m_n,
const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_n,
const CountGridDesc_M_NBlock& count_grid_desc_m_n,
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N,
......@@ -116,13 +118,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
p_in_welford_mean_grid, mean_var_grid_desc_m_n.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
p_in_welford_var_grid, mean_var_grid_desc_m_n.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
p_in_welford_count_grid, count_grid_desc_m_n.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
......@@ -173,7 +175,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType,
MeanVarCountGridDesc_M_NBlock,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
......@@ -181,7 +183,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1,
1,
true>(
mean_var_count_grid_desc_m_n,
mean_var_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
......@@ -189,7 +191,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<VarDataType,
ComputeDataType,
MeanVarCountGridDesc_M_NBlock,
MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
......@@ -197,7 +199,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1,
1,
true>(
mean_var_count_grid_desc_m_n,
mean_var_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
......@@ -205,7 +207,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto threadwise_count_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_NBlock,
CountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder,
......@@ -213,7 +215,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1,
1,
true>(
mean_var_count_grid_desc_m_n,
count_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id));
......@@ -292,19 +294,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
++reducedTiles)
{
threadwise_mean_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_n,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_n,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
threadwise_count_load_m_nblock.Run(count_grid_desc_m_n,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
......@@ -317,11 +319,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment