Commit 003ec407 authored by rocking's avatar rocking
Browse files

Add welford count

parent b7f500f0
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" // #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
...@@ -33,8 +34,7 @@ template <typename GridwiseGemmWelford, ...@@ -33,8 +34,7 @@ template <typename GridwiseGemmWelford,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename MeanGridDescriptor_MBlock_MPerBlock_NBlock, typename MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock,
typename VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap, typename Block2ETileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -46,8 +46,9 @@ __global__ void ...@@ -46,8 +46,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_mean_grid, MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_var_grid, VarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
...@@ -57,8 +58,8 @@ __global__ void ...@@ -57,8 +58,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock mean_grid_desc_mblock_mperblock_nblock, const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
const VarGridDescriptor_MBlock_MPerBlock_NBlock var_grid_desc_mblock_mperblock_nblock, mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -69,8 +70,9 @@ __global__ void ...@@ -69,8 +70,9 @@ __global__ void
p_b_grid, p_b_grid,
p_ds_grid, p_ds_grid,
p_e_grid, p_e_grid,
p_mean_grid, p_welford_mean_grid,
p_var_grid, p_welford_var_grid,
p_welford_count_grid,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -79,16 +81,16 @@ __global__ void ...@@ -79,16 +81,16 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_grid_desc_mblock_mperblock_nblock, mean_var_count_grid_desc_mblock_mperblock_nblock,
var_grid_desc_mblock_mperblock_nblock,
block_2_etile_map); block_2_etile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = p_mean_grid; ignore = p_welford_mean_grid;
ignore = p_var_grid; ignore = p_welford_var_grid;
ignore = p_welford_count_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = cde_element_op; ignore = cde_element_op;
...@@ -96,29 +98,28 @@ __global__ void ...@@ -96,29 +98,28 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = mean_grid_desc_mblock_mperblock_nblock; ignore = mean_var_count_grid_desc_mblock_mperblock_nblock;
ignore = var_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map; ignore = block_2_etile_map;
#endif #endif
} }
template <typename GridwiseWelfordLayernorm, // template <typename GridwiseWelfordLayernorm,
typename EDataType, // typename EDataType,
typename HDataType, // typename HDataType,
typename MeanDataType, // typename MeanDataType,
typename VarDataType> // typename VarDataType>
__global__ void // __global__ void
#if CK_USE_LAUNCH_BOUNDS // #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif // #endif
kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid, // kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid, // const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid, // const VarDataType* __restrict__ p_var_grid,
HDataType* __restrict__ p_y_grid, // HDataType* __restrict__ p_y_grid,
index_t blkgroup_size) // index_t blkgroup_size)
{ // {
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size); // // GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
} // }
} // namespace ck } // namespace ck
...@@ -326,14 +327,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -326,14 +327,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
} }
}; };
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanVarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1)); using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1)); using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1)); using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
...@@ -351,8 +352,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -351,8 +352,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K, BGridDesc_N_K,
DsGridDesc_M_N, DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
MeanVarGridDesc_M_N, MeanVarCountGridDesc_M_N,
MeanVarGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -384,32 +384,31 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -384,32 +384,31 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
PostShuffleThreadClusterSize_M_N, PostShuffleThreadClusterSize_M_N,
PostShuffleScalarPerVector, PostShuffleScalarPerVector,
1,
LoopSched>; LoopSched>;
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm = // using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType, // GridwiseWelfordSecondHalfLayernorm2d<EDataType,
HDataType, // HDataType,
MeanDataType, // MeanDataType,
VarDataType, // VarDataType,
AccDataType, // AccDataType,
HGridDesc_M_N, // HGridDesc_M_N,
MeanVarGridDesc_M_N, // MeanVarGridDesc_M_N,
GammaBetaGridDesc_N, // GammaBetaGridDesc_N,
MeanVarGridDesc_M, // MeanVarGridDesc_M,
BlockSize, // BlockSize,
LayernormThreadClusterSize_M_N::At(I0), // LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1), // LayernormThreadClusterSize_M_N::At(I1),
LayernormThreadSliceSize_M_N::At(I0), // LayernormThreadSliceSize_M_N::At(I0),
LayernormThreadSliceSize_M_N::At(I1), // LayernormThreadSliceSize_M_N::At(I1),
LayernormESrcHDstVectorDim, // LayernormESrcHDstVectorDim,
LayernormESrcVectorSize, // LayernormESrcVectorSize,
LayernormHDstVectorSize, // LayernormHDstVectorSize,
LayernormGammaSrcVectorSize, // LayernormGammaSrcVectorSize,
LayernormBetaSrcVectorSize, // LayernormBetaSrcVectorSize,
LayernormMeanVarSrcDstVectorSize>; // LayernormMeanVarSrcDstVectorSize>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -436,8 +435,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -436,8 +435,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{nullptr}, p_e_grid_{nullptr},
p_mean_grid_{nullptr}, p_welford_mean_grid_{nullptr},
p_var_grid_{nullptr}, p_welford_var_grid_{nullptr},
p_welford_count_grid_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)}, p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)}, p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)}, p_h_grid_{static_cast<HDataType*>(p_h_grid)},
...@@ -445,8 +445,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -445,8 +445,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)}, e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_grid_desc_m_n_{}, mean_var_count_grid_desc_m_n_{},
var_grid_desc_m_n_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)}, beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)}, h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
...@@ -458,16 +457,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -458,16 +457,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)}, blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon} epsilon_{epsilon}
{ {
mean_grid_desc_m_n_ = mean_var_count_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
var_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_); DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw)); hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
int gemm_welford_size = MRaw * blkGroupSize_; int gemm_welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * gemm_welford_size)); hip_check_error(
hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * gemm_welford_size)); hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(
hipMalloc(&p_welford_var_grid_, sizeof(VarDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_welford_count_grid_, sizeof(int32_t) * gemm_welford_size));
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -487,8 +487,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -487,8 +487,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_, b_grid_desc_n_k_,
ds_grid_desc_m_n_, ds_grid_desc_m_n_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
mean_grid_desc_m_n_, mean_var_count_grid_desc_m_n_,
var_grid_desc_m_n_,
block_2_etile_map_)) block_2_etile_map_))
{ {
ds_grid_desc_mblock_mperblock_nblock_nperblock_ = ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
...@@ -499,13 +498,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -499,13 +498,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
mean_grid_desc_mblock_mperblock_nblock_ = mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock( GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_grid_desc_m_n_); mean_var_count_grid_desc_m_n_);
var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
var_grid_desc_m_n_);
} }
// TODO - H // TODO - H
...@@ -527,8 +522,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -527,8 +522,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
MeanDataType* p_mean_grid_; // mean MeanDataType* p_welford_mean_grid_;
VarDataType* p_var_grid_; // variance * count VarDataType* p_welford_var_grid_;
int32_t* p_welford_count_grid_;
const GammaDataType* p_gamma_grid_; const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_; const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_; HDataType* p_h_grid_;
...@@ -538,8 +534,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -538,8 +534,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
MeanVarGridDesc_M_N mean_grid_desc_m_n_; MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
MeanVarGridDesc_M_N var_grid_desc_m_n_;
GammaBetaGridDesc_N gamma_grid_desc_n_; GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_; GammaBetaGridDesc_N beta_grid_desc_n_;
HGridDesc_M_N h_grid_desc_m_n_; HGridDesc_M_N h_grid_desc_m_n_;
...@@ -551,10 +546,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -551,10 +546,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock typename GridwiseGemmWelford::MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock_; mean_var_count_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock
var_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
...@@ -582,8 +575,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -582,8 +575,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.mean_grid_desc_m_n_, arg.mean_var_count_grid_desc_m_n_,
arg.var_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting"); throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
...@@ -615,17 +607,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -615,17 +607,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford:: typename GridwiseGemmWelford::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock, typename GridwiseGemmWelford::
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock, MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap, typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
const auto kernel_welford_layernorm = // const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm, // kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType, // EDataType,
HDataType, // HDataType,
MeanDataType, // MeanDataType,
VarDataType>; // VarDataType>;
avg_time += avg_time +=
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
...@@ -637,8 +629,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -637,8 +629,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_, arg.p_b_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, arg.p_e_grid_,
arg.p_mean_grid_, arg.p_welford_mean_grid_,
arg.p_var_grid_, arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
...@@ -646,20 +639,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -646,20 +639,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.mean_grid_desc_mblock_mperblock_nblock_, arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.var_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config, // avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm, // kernel_welford_layernorm,
dim3(grid_size), // dim3(grid_size),
dim3(BlockSize), // dim3(BlockSize),
0, // 0,
arg.p_e_grid_, // arg.p_e_grid_,
arg.p_mean_grid_, // arg.p_welford_mean_grid_,
arg.p_var_grid_, // arg.p_welford_var_grid_,
arg.p_h_grid_, // arg.p_h_grid_,
arg.blkGroupSize_); // arg.blkGroupSize_);
return avg_time; return avg_time;
}; };
......
...@@ -47,8 +47,7 @@ template <typename ABDataType, ...@@ -47,8 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K, typename BGridDesc_N_K,
typename DsGridDesc_M_N, typename DsGridDesc_M_N,
typename EGridDesc_M_N, typename EGridDesc_M_N,
typename MeanGridDesc_M_N, typename MeanVarCountGridDesc_M_N,
typename VarGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -80,7 +79,6 @@ template <typename ABDataType, ...@@ -80,7 +79,6 @@ template <typename ABDataType,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N, typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector, index_t PostShuffleScalarPerVector,
index_t MeanVarTransferScalarPerVector,
LoopScheduler LoopSched> LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{ {
...@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock // TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N> template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n) MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{ {
const auto M = grid_desc_m_n.GetLength(I0); const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1); const auto NBlock = grid_desc_m_n.GetLength(I1);
...@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k, const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n, const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const MeanGridDesc_M_N& mean_grid_desc_m_n, const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const VarGridDesc_M_N& var_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
...@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
M == mean_grid_desc_m_n.GetLength(I0) && M == var_grid_desc_m_n.GetLength(I0) && M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_grid_desc_m_n.GetLength(I1) && N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
N / NPerBlock == var_grid_desc_m_n.GetLength(I1)))
{ {
return false; return false;
} }
...@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype( using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(MeanGridDesc_M_N{}))>; MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
using VarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(VarGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
...@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
Run(const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid,
const ABDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid,
DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid,
EDataType* __restrict__ p_e_grid, MeanDataType* __restrict__ p_welford_mean_grid,
MeanDataType* __restrict__ p_mean_grid, VarDataType* __restrict__ p_welford_var_grid,
VarDataType* __restrict__ p_var_grid, int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op, const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock& mean_grid_desc_mblock_mperblock_nblock, const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
const VarGridDescriptor_MBlock_MPerBlock_NBlock& var_grid_desc_mblock_mperblock_nblock, mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_grid, mean_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); p_welford_mean_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_var_grid, var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); p_welford_var_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>( decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize())); thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords; Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs; Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs; Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
static_for<0, num_shuffleM, 1>{}([&](auto i) { static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding // TODO - padding
...@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>( var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()); thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f); mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f); var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
}); });
}); });
...@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// Blockwise welford and write out // Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) { static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i); auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i); auto& var_thread_buf = var_thread_bufs(i);
int count = threadwise_welfords(i).cur_count_; auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds(); block_sync_lds();
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count); BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
}); });
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
...@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr int shuffleMPerBlock = constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(PostShuffleThreadSliceSize_M % MeanVarTransferScalarPerVector == 0);
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
MeanDataType, MeanDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_grid_desc_mblock_mperblock_nblock), decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
1, 1,
MeanVarTransferScalarPerVector, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{mean_grid_desc_mblock_mperblock_nblock, false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
...@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType, AccDataType,
VarDataType, VarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(var_grid_desc_mblock_mperblock_nblock), decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
1, 1,
MeanVarTransferScalarPerVector, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{var_grid_desc_mblock_mperblock_nblock, false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, mean_thread_copy_vgpr_to_global.Run(
make_tuple(I0, I0, I0), thread_welford_desc_I_m_I,
mean_thread_buf, make_tuple(I0, I0, I0),
mean_grid_desc_mblock_mperblock_nblock, mean_thread_buf,
mean_grid_buf); mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
var_thread_buf, var_thread_buf,
var_grid_desc_mblock_mperblock_nblock, mean_var_count_grid_desc_mblock_mperblock_nblock,
var_grid_buf); var_grid_buf);
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
}); });
} // shuffle C + Ds + welford + write out } // shuffle C + Ds + welford + write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment