Commit 3df07c27 authored by rocking's avatar rocking
Browse files

Use 1D global memory for count

parent 39dedce7
...@@ -33,7 +33,8 @@ template <typename GridwiseGemmWelford, ...@@ -33,7 +33,8 @@ template <typename GridwiseGemmWelford,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock, typename MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
typename CountGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap, typename Block2ETileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -57,8 +58,10 @@ __global__ void ...@@ -57,8 +58,10 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock, mean_var_grid_desc_mblock_mperblock_nblock,
const CountGridDescriptor_MBlock_MPerBlock_NBlock
count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map, const Block2ETileMap block_2_etile_map,
index_t NRaw) index_t NRaw)
{ {
...@@ -81,7 +84,8 @@ __global__ void ...@@ -81,7 +84,8 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_count_grid_desc_mblock_mperblock_nblock, mean_var_grid_desc_mblock_mperblock_nblock,
count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map, block_2_etile_map,
NRaw); NRaw);
#else #else
...@@ -99,7 +103,8 @@ __global__ void ...@@ -99,7 +103,8 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = mean_var_count_grid_desc_mblock_mperblock_nblock; ignore = mean_var_grid_desc_mblock_mperblock_nblock;
ignore = count_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map; ignore = block_2_etile_map;
ignore = NRaw; ignore = NRaw;
#endif #endif
...@@ -114,7 +119,8 @@ template <typename GridwiseWelfordLayernorm, ...@@ -114,7 +119,8 @@ template <typename GridwiseWelfordLayernorm,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename EHGridDesc_M_N, typename EHGridDesc_M_N,
typename LayernormMeanVarCountGridDesc_M_NBlock, typename LayernormMeanVarGridDesc_M_NBlock,
typename LayernormCountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N, typename GammaBetaGridDesc_N,
typename HElementwiseOperation> typename HElementwiseOperation>
__global__ void __global__ void
...@@ -131,7 +137,8 @@ __global__ void ...@@ -131,7 +137,8 @@ __global__ void
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N e_grid_desc_m_n,
const EHGridDesc_M_N h_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n,
const LayernormMeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock, const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N gamma_grid_desc_n,
const GammaBetaGridDesc_N beta_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
...@@ -148,7 +155,8 @@ __global__ void ...@@ -148,7 +155,8 @@ __global__ void
p_h_grid, p_h_grid,
e_grid_desc_m_n, e_grid_desc_m_n,
h_grid_desc_m_n, h_grid_desc_m_n,
mean_var_count_grid_desc_m_nblock, mean_var_grid_desc_m_nblock,
count_grid_desc_m_nblock,
gamma_grid_desc_n, gamma_grid_desc_n,
beta_grid_desc_n, beta_grid_desc_n,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
...@@ -315,7 +323,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -315,7 +323,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
template <typename LayOut, typename DoPads, index_t MPerTile, index_t NPerTile> template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N) static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
{ {
const auto grid_desc_m_n = const auto grid_desc_m_n =
...@@ -323,6 +331,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -323,6 +331,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
} }
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeCountDescriptor_M_N(index_t M, index_t N)
{
const auto grid_desc_m_n =
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1));
return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
}
template <index_t XPerTile> template <index_t XPerTile>
static auto MakeDescriptor_X(index_t X) static auto MakeDescriptor_X(index_t X)
{ {
...@@ -335,15 +351,22 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -335,15 +351,22 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid // We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding) // layout(different padding)
using GemmMeanVarCountGridDesc_M_NBlock = decltype( using GemmMeanVarGridDesc_M_NBlock =
MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>(1, 1)); decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using GemmCountGridDesc_M_NBlock =
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
using LayernormMeanVarCountGridDesc_M_NBlock = using LayernormMeanVarGridDesc_M_NBlock =
decltype(MakeMeanVarDescriptor_M_N<HLayout, decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1)); LayernormBlockTileSize_M_N::At(1)>(1, 1));
using LayernormCountGridDesc_M_NBlock =
decltype(MakeCountDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(1));
using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1)); using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N<HLayout>(1, 1, 1));
...@@ -363,7 +386,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -363,7 +386,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K, BGridDesc_N_K,
DsGridDesc_M_N, DsGridDesc_M_N,
EHGridDesc_M_N, EHGridDesc_M_N,
GemmMeanVarCountGridDesc_M_NBlock, GemmMeanVarGridDesc_M_NBlock,
GemmCountGridDesc_M_NBlock,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -408,7 +432,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -408,7 +432,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType, BetaDataType,
AccDataType, AccDataType,
EHGridDesc_M_N, EHGridDesc_M_N,
LayernormMeanVarCountGridDesc_M_NBlock, LayernormMeanVarGridDesc_M_NBlock,
LayernormCountGridDesc_M_NBlock,
GammaBetaGridDesc_N, GammaBetaGridDesc_N,
HElementwiseOperation, HElementwiseOperation,
BlockSize, BlockSize,
...@@ -456,8 +481,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -456,8 +481,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)}, e_grid_desc_m_n_{DeviceOp::MakeEHGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
gemm_mean_var_count_grid_desc_m_nblock_{}, gemm_mean_var_grid_desc_m_nblock_{},
layernorm_mean_var_count_grid_desc_m_nblock_{}, gemm_count_grid_desc_m_nblock_{},
layernorm_mean_var_grid_desc_m_nblock_{},
layernorm_count_grid_desc_m_nblock_{},
gamma_grid_desc_n_{ gamma_grid_desc_n_{
DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)}, DeviceOp::MakeDescriptor_X<LayernormBlockTileSize_M_N::At(1)>(NRaw)},
beta_grid_desc_n_{ beta_grid_desc_n_{
...@@ -478,17 +505,26 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -478,17 +505,26 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)}, gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
gemm_mean_var_count_grid_desc_m_nblock_ = DeviceOp:: gemm_mean_var_grid_desc_m_nblock_ =
MakeMeanVarDescriptor_M_N<HLayout, Sequence<true, false>, MPerBlock, NPerBlock>( DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
layernorm_mean_var_count_grid_desc_m_nblock_ = gemm_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<HLayout, DeviceOp::MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(
Sequence<true, true>, MRaw, gemm_nblock_);
layernorm_mean_var_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>( LayernormBlockTileSize_M_N::At(1)>(
MRaw, gemm_nblock_); MRaw, gemm_nblock_);
layernorm_count_grid_desc_m_nblock_ =
DeviceOp::MakeCountDescriptor_M_N<Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw,
gemm_nblock_);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
...@@ -517,9 +553,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -517,9 +553,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
mean_var_count_grid_desc_mblock_mperblock_nblock_ = gemm_mean_var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_mean_var_count_grid_desc_m_nblock_); gemm_mean_var_grid_desc_m_nblock_);
gemm_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
gemm_count_grid_desc_m_nblock_);
} }
} }
...@@ -551,8 +591,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -551,8 +591,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EHGridDesc_M_N e_grid_desc_m_n_; EHGridDesc_M_N e_grid_desc_m_n_;
GemmMeanVarCountGridDesc_M_NBlock gemm_mean_var_count_grid_desc_m_nblock_; GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_;
LayernormMeanVarCountGridDesc_M_NBlock layernorm_mean_var_count_grid_desc_m_nblock_; GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_;
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_;
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_;
GammaBetaGridDesc_N gamma_grid_desc_n_; GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_; GammaBetaGridDesc_N beta_grid_desc_n_;
EHGridDesc_M_N h_grid_desc_m_n_; EHGridDesc_M_N h_grid_desc_m_n_;
...@@ -564,8 +606,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -564,8 +606,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock_; gemm_mean_var_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock
gemm_count_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
...@@ -628,8 +672,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -628,8 +672,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford:: typename GridwiseGemmWelford::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford:: typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap, typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
...@@ -643,7 +687,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -643,7 +687,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType, BetaDataType,
AccDataType, AccDataType,
EHGridDesc_M_N, EHGridDesc_M_N,
LayernormMeanVarCountGridDesc_M_NBlock, LayernormMeanVarGridDesc_M_NBlock,
LayernormCountGridDesc_M_NBlock,
GammaBetaGridDesc_N, GammaBetaGridDesc_N,
HElementwiseOperation>; HElementwiseOperation>;
...@@ -667,7 +712,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -667,7 +712,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_, arg.gemm_mean_var_grid_desc_mblock_mperblock_nblock_,
arg.gemm_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_, arg.block_2_etile_map_,
arg.NRaw_); arg.NRaw_);
...@@ -694,7 +740,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -694,7 +740,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_h_grid_, arg.p_h_grid_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_, arg.h_grid_desc_m_n_,
arg.layernorm_mean_var_count_grid_desc_m_nblock_, arg.layernorm_mean_var_grid_desc_m_nblock_,
arg.layernorm_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_, arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_, arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N, numMeanVarCountBlockTileIteration_N,
...@@ -738,7 +785,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -738,7 +785,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
workspace_size += gemm_welford_size * sizeof(VarDataType) + 64; workspace_size += gemm_welford_size * sizeof(VarDataType) + 64;
// workspace for welford intermediate count // workspace for welford intermediate count
workspace_size += gemm_welford_size * sizeof(int32_t) + 64; workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64;
return (workspace_size); return (workspace_size);
}; };
......
...@@ -47,7 +47,8 @@ template <typename ABDataType, ...@@ -47,7 +47,8 @@ template <typename ABDataType,
typename BGridDesc_N_K, typename BGridDesc_N_K,
typename DsGridDesc_M_N, typename DsGridDesc_M_N,
typename EGridDesc_M_N, typename EGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock, typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -347,8 +348,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -347,8 +348,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype( using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>; MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
...@@ -361,27 +364,29 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -361,27 +364,29 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid, __device__ static void
const ABDataType* __restrict__ p_b_grid, Run(const ABDataType* __restrict__ p_a_grid,
DsGridPointer p_ds_grid, const ABDataType* __restrict__ p_b_grid,
EDataType* __restrict__ p_e_grid, DsGridPointer p_ds_grid,
MeanDataType* __restrict__ p_welford_mean_grid, EDataType* __restrict__ p_e_grid,
VarDataType* __restrict__ p_welford_var_grid, MeanDataType* __restrict__ p_welford_mean_grid,
int32_t* __restrict__ p_welford_count, VarDataType* __restrict__ p_welford_var_grid,
void* __restrict__ p_shared, int32_t* __restrict__ p_welford_count,
const AElementwiseOperation& a_element_op, void* __restrict__ p_shared,
const BElementwiseOperation& b_element_op, const AElementwiseOperation& a_element_op,
const CDEElementwiseOperation& cde_element_op, const BElementwiseOperation& b_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const CDEElementwiseOperation& cde_element_op,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock& e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_var_count_grid_desc_mblock_mperblock_nblock, const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock&
const Block2ETileMap& block_2_etile_map, mean_var_grid_desc_mblock_mperblock_nblock,
index_t NRaw) const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map,
index_t NRaw)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -401,16 +406,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -401,16 +406,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean_grid, p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_var_grid, p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -880,7 +882,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -880,7 +882,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs; Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
int max_count = PostShuffleThreadSliceSize_N * num_shuffleN; int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
const auto nblock = mean_var_count_grid_desc_mblock_mperblock_nblock.GetLength(I2); const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
// tail block // tail block
if(block_work_idx[I1] % nblock == nblock - 1) if(block_work_idx[I1] % nblock == nblock - 1)
...@@ -1038,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1038,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType, AccDataType,
MeanDataType, MeanDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock), decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
...@@ -1046,7 +1048,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1046,7 +1048,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock, false>{mean_var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
...@@ -1057,7 +1059,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1057,7 +1059,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType, AccDataType,
VarDataType, VarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock), decltype(mean_var_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
...@@ -1065,7 +1067,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1065,7 +1067,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock, false>{mean_var_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
...@@ -1076,7 +1078,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1076,7 +1078,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
int32_t, int32_t,
int32_t, int32_t,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock), decltype(count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
...@@ -1084,32 +1086,30 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1084,32 +1086,30 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock, false>{count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run( mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
thread_welford_desc_I_m_I, make_tuple(I0, I0, I0),
make_tuple(I0, I0, I0), mean_thread_buf,
mean_thread_buf, mean_var_grid_desc_mblock_mperblock_nblock,
mean_var_count_grid_desc_mblock_mperblock_nblock, mean_grid_buf);
mean_grid_buf);
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
var_thread_buf, var_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock, mean_var_grid_desc_mblock_mperblock_nblock,
var_grid_buf); var_grid_buf);
count_thread_copy_vgpr_to_global.Run( count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
thread_welford_desc_I_m_I, make_tuple(I0, I0, I0),
make_tuple(I0, I0, I0), count_thread_buf,
count_thread_buf, count_grid_desc_mblock_mperblock_nblock,
mean_var_count_grid_desc_mblock_mperblock_nblock, welford_count_grid_buf);
welford_count_grid_buf);
}); });
} // shuffle C + Ds + welford + write out } // shuffle C + Ds + welford + write out
......
...@@ -27,7 +27,8 @@ template <typename EDataType, ...@@ -27,7 +27,8 @@ template <typename EDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename EHGridDesc_M_N, typename EHGridDesc_M_N,
typename MeanVarCountGridDesc_M_NBlock, typename MeanVarGridDesc_M_NBlock,
typename CountGridDesc_M_NBlock,
typename GammaBetaGridDesc_N, typename GammaBetaGridDesc_N,
typename HElementwiseOperation, typename HElementwiseOperation,
index_t BlockSize, index_t BlockSize,
...@@ -95,7 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -95,7 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType* __restrict__ p_h_grid, HDataType* __restrict__ p_h_grid,
const EHGridDesc_M_N& e_grid_desc_m_n, const EHGridDesc_M_N& e_grid_desc_m_n,
const EHGridDesc_M_N& h_grid_desc_m_n, const EHGridDesc_M_N& h_grid_desc_m_n,
const MeanVarCountGridDesc_M_NBlock& mean_var_count_grid_desc_m_n, const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_n,
const CountGridDesc_M_NBlock& count_grid_desc_m_n,
const GammaBetaGridDesc_N& gamma_grid_desc_n, const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n, const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t numMeanVarCountBlockTileIteration_N, index_t numMeanVarCountBlockTileIteration_N,
...@@ -116,13 +118,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -116,13 +118,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
p_e_grid, e_grid_desc_m_n.GetElementSpaceSize()); p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize()); p_in_welford_mean_grid, mean_var_grid_desc_m_n.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize()); p_in_welford_var_grid, mean_var_grid_desc_m_n.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize()); p_in_welford_count_grid, count_grid_desc_m_n.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize()); p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
...@@ -173,7 +175,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -173,7 +175,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto threadwise_mean_load_m_nblock = auto threadwise_mean_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<MeanDataType, ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType, ComputeDataType,
MeanVarCountGridDesc_M_NBlock, MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
...@@ -181,7 +183,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -181,7 +183,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
1, 1,
true>( true>(
mean_var_count_grid_desc_m_n, mean_var_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -189,7 +191,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -189,7 +191,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto threadwise_var_load_m_nblock = auto threadwise_var_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<VarDataType, ThreadwiseTensorSliceTransfer_v2<VarDataType,
ComputeDataType, ComputeDataType,
MeanVarCountGridDesc_M_NBlock, MeanVarGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
...@@ -197,7 +199,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -197,7 +199,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
1, 1,
true>( true>(
mean_var_count_grid_desc_m_n, mean_var_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -205,7 +207,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -205,7 +207,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto threadwise_count_load_m_nblock = auto threadwise_count_load_m_nblock =
ThreadwiseTensorSliceTransfer_v2<int32_t, ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t, int32_t,
MeanVarCountGridDesc_M_NBlock, CountGridDesc_M_NBlock,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
...@@ -213,7 +215,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -213,7 +215,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1, 1,
1, 1,
true>( true>(
mean_var_count_grid_desc_m_n, count_grid_desc_m_n,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id)); thread_n_cluster_id));
...@@ -292,19 +294,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -292,19 +294,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N; for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
++reducedTiles) ++reducedTiles)
{ {
threadwise_mean_load_m_nblock.Run(mean_var_count_grid_desc_m_n, threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_n,
welford_mean_global_val_buf, welford_mean_global_val_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_mean_thread_buf); in_welford_mean_thread_buf);
threadwise_var_load_m_nblock.Run(mean_var_count_grid_desc_m_n, threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_n,
welford_var_global_val_buf, welford_var_global_val_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_var_thread_buf); in_welford_var_thread_buf);
threadwise_count_load_m_nblock.Run(mean_var_count_grid_desc_m_n, threadwise_count_load_m_nblock.Run(count_grid_desc_m_n,
welford_count_global_val_buf, welford_count_global_val_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
...@@ -317,11 +319,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -317,11 +319,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf, welford_var_thread_buf,
welford_count_thread_buf); welford_count_thread_buf);
threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n, threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n); mean_var_count_thread_copy_step_m_n);
threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n, threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n); mean_var_count_thread_copy_step_m_n);
threadwise_count_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n, threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n,
mean_var_count_thread_copy_step_m_n); mean_var_count_thread_copy_step_m_n);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment