Commit 003ec407 authored by rocking

Add Welford count

parent b7f500f0
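
This commit extends the fused GEMM + LayerNorm pipeline so that the first-half kernel writes out an explicit element count next to each partial Welford mean and variance, and it collapses the separate mean/var grid descriptors into a single mean/var/count descriptor. For background, below is a minimal, standalone sketch of count-aware Welford aggregation. It is illustrative only and does not reproduce CK's ThreadwiseWelford/BlockwiseWelford code; the names WelfordPartial, welford_update, and welford_merge are hypothetical.

// Illustrative sketch only: count-aware Welford aggregation.
// None of the names below are CK APIs.
#include <cstdint>
#include <cstdio>

struct WelfordPartial
{
    float mean    = 0.0f; // running mean of the elements seen so far
    float m2      = 0.0f; // running sum of squared deviations (variance * count)
    int32_t count = 0;    // number of elements folded into this partial
};

// Fold one value into a partial aggregate (Welford's update step).
inline void welford_update(WelfordPartial& w, float x)
{
    w.count += 1;
    float delta = x - w.mean;
    w.mean += delta / static_cast<float>(w.count);
    w.m2 += delta * (x - w.mean);
}

// Merge two partial aggregates (Chan et al. parallel form). The counts are
// required here, which is why a per-partial count must be written out when the
// reduction is split across tiles or thread blocks.
inline WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;
    float delta = b.mean - a.mean;
    out.mean = a.mean + delta * static_cast<float>(b.count) / static_cast<float>(out.count);
    out.m2   = a.m2 + b.m2 +
             delta * delta * static_cast<float>(a.count) * static_cast<float>(b.count) /
                 static_cast<float>(out.count);
    return out;
}

int main()
{
    // Two "tiles" of the same row, reduced independently and then merged.
    float row[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
    WelfordPartial left, right;
    for(int i = 0; i < 3; ++i) welford_update(left, row[i]);
    for(int i = 3; i < 6; ++i) welford_update(right, row[i]);
    WelfordPartial whole = welford_merge(left, right);
    float variance = whole.m2 / static_cast<float>(whole.count);
    std::printf("mean = %f, variance = %f, count = %d\n", whole.mean, variance, whole.count);
    return 0;
}
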
......@@ -13,7 +13,8 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
// #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
......@@ -33,8 +34,7 @@ template <typename GridwiseGemmWelford,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -46,8 +46,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_mean_grid,
VarDataType* __restrict__ p_var_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
......@@ -57,8 +58,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock mean_grid_desc_mblock_mperblock_nblock,
const VarGridDescriptor_MBlock_MPerBlock_NBlock var_grid_desc_mblock_mperblock_nblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -69,8 +70,9 @@ __global__ void
p_b_grid,
p_ds_grid,
p_e_grid,
p_mean_grid,
p_var_grid,
p_welford_mean_grid,
p_welford_var_grid,
p_welford_count_grid,
p_shared,
a_element_op,
b_element_op,
......@@ -79,16 +81,16 @@ __global__ void
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
mean_grid_desc_mblock_mperblock_nblock,
var_grid_desc_mblock_mperblock_nblock,
mean_var_count_grid_desc_mblock_mperblock_nblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = p_mean_grid;
ignore = p_var_grid;
ignore = p_welford_mean_grid;
ignore = p_welford_var_grid;
ignore = p_welford_count_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
......@@ -96,29 +98,28 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = mean_grid_desc_mblock_mperblock_nblock;
ignore = var_grid_desc_mblock_mperblock_nblock;
ignore = mean_var_count_grid_desc_mblock_mperblock_nblock;
ignore = block_2_etile_map;
#endif
}
template <typename GridwiseWelfordLayernorm,
typename EDataType,
typename HDataType,
typename MeanDataType,
typename VarDataType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
const MeanDataType* __restrict__ p_mean_grid,
const VarDataType* __restrict__ p_var_grid,
HDataType* __restrict__ p_y_grid,
index_t blkgroup_size)
{
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
}
// template <typename GridwiseWelfordLayernorm,
// typename EDataType,
// typename HDataType,
// typename MeanDataType,
// typename VarDataType>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// #endif
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
// const MeanDataType* __restrict__ p_mean_grid,
// const VarDataType* __restrict__ p_var_grid,
// HDataType* __restrict__ p_y_grid,
// index_t blkgroup_size)
// {
// // GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
// }
} // namespace ck
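
With the change above, the first-half kernel emits per-row, per-N-block partial (mean, variance, count) triples into workspace buffers of shape [M, blkGroupSize]; the second-half kernel that would merge those partials and apply the LayerNorm is commented out in this commit. As a hedged reference for what that second pass conceptually computes, here is a plain host-side sketch reusing WelfordPartial and welford_merge from the sketch above (it additionally needs <cmath>). The function name, and the assumption that the variance buffer holds the M2 term (variance * count, as the pre-change comment in Argument suggested), are illustrative and not taken from CK.

// Host-side sketch (not the CK gridwise kernel) of the intended second-half pass,
// assuming the welford_merge() helper from the earlier sketch. Names are hypothetical.
#include <cmath>

void layernorm_second_half_reference(const float* e,             // [M, N] GEMM output
                                     const float* partial_mean,  // [M, blk_group_size]
                                     const float* partial_var,   // assumed M2 = variance * count
                                     const int32_t* partial_cnt, // [M, blk_group_size]
                                     const float* gamma,         // [N]
                                     const float* beta,          // [N]
                                     float* h,                   // [M, N] normalized output
                                     int M, int N, int blk_group_size, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        // Merge the per-N-block partials of this row into one aggregate.
        WelfordPartial row{};
        for(int b = 0; b < blk_group_size; ++b)
        {
            WelfordPartial p;
            p.mean  = partial_mean[m * blk_group_size + b];
            p.m2    = partial_var[m * blk_group_size + b];
            p.count = partial_cnt[m * blk_group_size + b];
            row     = welford_merge(row, p);
        }
        float variance = row.m2 / static_cast<float>(row.count);
        float inv_std  = 1.0f / std::sqrt(variance + epsilon);
        // Apply LayerNorm with per-column gamma/beta.
        for(int n = 0; n < N; ++n)
            h[m * N + n] = gamma[n] * (e[m * N + n] - row.mean) * inv_std + beta[n];
    }
}
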
......@@ -326,14 +327,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
};
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanVarGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
......@@ -351,8 +352,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
MeanVarGridDesc_M_N,
MeanVarGridDesc_M_N,
MeanVarCountGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -384,32 +384,31 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleNXdlPerWavePerShuffle,
PostShuffleThreadClusterSize_M_N,
PostShuffleScalarPerVector,
1,
LoopSched>;
using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
using GridwiseWelfordLayernorm =
GridwiseWelfordSecondHalfLayernorm2d<EDataType,
HDataType,
MeanDataType,
VarDataType,
AccDataType,
HGridDesc_M_N,
MeanVarGridDesc_M_N,
GammaBetaGridDesc_N,
MeanVarGridDesc_M,
BlockSize,
LayernormThreadClusterSize_M_N::At(I0),
LayernormThreadClusterSize_M_N::At(I1),
LayernormThreadSliceSize_M_N::At(I0),
LayernormThreadSliceSize_M_N::At(I1),
LayernormESrcHDstVectorDim,
LayernormESrcVectorSize,
LayernormHDstVectorSize,
LayernormGammaSrcVectorSize,
LayernormBetaSrcVectorSize,
LayernormMeanVarSrcDstVectorSize>;
// using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType,
// MeanDataType,
// VarDataType,
// AccDataType,
// HGridDesc_M_N,
// MeanVarGridDesc_M_N,
// GammaBetaGridDesc_N,
// MeanVarGridDesc_M,
// BlockSize,
// LayernormThreadClusterSize_M_N::At(I0),
// LayernormThreadClusterSize_M_N::At(I1),
// LayernormThreadSliceSize_M_N::At(I0),
// LayernormThreadSliceSize_M_N::At(I1),
// LayernormESrcHDstVectorDim,
// LayernormESrcVectorSize,
// LayernormHDstVectorSize,
// LayernormGammaSrcVectorSize,
// LayernormBetaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>;
// Argument
struct Argument : public BaseArgument
......@@ -436,8 +435,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{nullptr},
p_mean_grid_{nullptr},
p_var_grid_{nullptr},
p_welford_mean_grid_{nullptr},
p_welford_var_grid_{nullptr},
p_welford_count_grid_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
......@@ -445,8 +445,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_grid_desc_m_n_{},
var_grid_desc_m_n_{},
mean_var_count_grid_desc_m_n_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
......@@ -458,16 +457,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
blkGroupSize_{math::integer_divide_ceil(NRaw, NPerBlock)},
epsilon_{epsilon}
{
mean_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
var_grid_desc_m_n_ =
mean_var_count_grid_desc_m_n_ =
DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, blkGroupSize_, blkGroupSize_);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
int gemm_welford_size = MRaw * blkGroupSize_;
hip_check_error(hipMalloc(&p_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_var_grid_, sizeof(VarDataType) * gemm_welford_size));
hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(
hipMalloc(&p_welford_var_grid_, sizeof(VarDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_welford_count_grid_, sizeof(int32_t) * gemm_welford_size));
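// Workspace sizing note (added for illustration, not part of the source): one
// (mean, variance, count) partial is produced per output row and per N-block,
// so each buffer holds MRaw * blkGroupSize_ elements. For example, with the
// hypothetical sizes MRaw = 1024, NRaw = 4096, NPerBlock = 128:
//   blkGroupSize_     = ceil(4096 / 128) = 32
//   gemm_welford_size = 1024 * 32        = 32768 elements per buffer.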
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -487,8 +487,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
mean_grid_desc_m_n_,
var_grid_desc_m_n_,
mean_var_count_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
......@@ -499,13 +498,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
mean_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
mean_grid_desc_m_n_);
var_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
var_grid_desc_m_n_);
mean_var_count_grid_desc_mblock_mperblock_nblock_ =
GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
mean_var_count_grid_desc_m_n_);
}
// TODO - H
......@@ -527,8 +522,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
MeanDataType* p_mean_grid_; // mean
VarDataType* p_var_grid_; // variance * count
MeanDataType* p_welford_mean_grid_;
VarDataType* p_welford_var_grid_;
int32_t* p_welford_count_grid_;
const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_;
......@@ -538,8 +534,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
MeanVarGridDesc_M_N mean_grid_desc_m_n_;
MeanVarGridDesc_M_N var_grid_desc_m_n_;
MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
GammaBetaGridDesc_N gamma_grid_desc_n_;
GammaBetaGridDesc_N beta_grid_desc_n_;
HGridDesc_M_N h_grid_desc_m_n_;
......@@ -551,10 +546,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock
var_grid_desc_mblock_mperblock_nblock_;
typename GridwiseGemmWelford::MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
......@@ -582,8 +575,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.mean_grid_desc_m_n_,
arg.var_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
......@@ -615,17 +607,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock,
typename GridwiseGemmWelford::DefaultBlock2ETileMap,
has_main_loop>;
const auto kernel_welford_layernorm =
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
EDataType,
HDataType,
MeanDataType,
VarDataType>;
// const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType,
// HDataType,
// MeanDataType,
// VarDataType>;
avg_time +=
launch_and_time_kernel(stream_config,
......@@ -637,8 +629,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_mean_grid_,
arg.p_var_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
......@@ -646,20 +639,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.mean_grid_desc_mblock_mperblock_nblock_,
arg.var_grid_desc_mblock_mperblock_nblock_,
arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
arg.block_2_etile_map_);
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_mean_grid_,
arg.p_var_grid_,
arg.p_h_grid_,
arg.blkGroupSize_);
// avg_time += launch_and_time_kernel(stream_config,
// kernel_welford_layernorm,
// dim3(grid_size),
// dim3(BlockSize),
// 0,
// arg.p_e_grid_,
// arg.p_welford_mean_grid_,
// arg.p_welford_var_grid_,
// arg.p_h_grid_,
// arg.blkGroupSize_);
return avg_time;
};
......
......@@ -47,8 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanGridDesc_M_N,
typename VarGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -80,7 +79,6 @@ template <typename ABDataType,
index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
index_t MeanVarTransferScalarPerVector,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
......@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{});
}
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1);
......@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const MeanGridDesc_M_N& mean_grid_desc_m_n,
const VarGridDesc_M_N& var_grid_desc_m_n,
const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
......@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
M == mean_grid_desc_m_n.GetLength(I0) && M == var_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_grid_desc_m_n.GetLength(I1) &&
N / NPerBlock == var_grid_desc_m_n.GetLength(I1)))
M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
{
return false;
}
......@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(MeanGridDesc_M_N{}))>;
using VarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(VarGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_mean_grid,
VarDataType* __restrict__ p_var_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock& mean_grid_desc_mblock_mperblock_nblock,
const VarGridDescriptor_MBlock_MPerBlock_NBlock& var_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map)
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_grid, mean_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_mean_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_var_grid, var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_var_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
......@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding
......@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
});
});
......@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
int count = threadwise_welfords(i).cur_count_;
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count);
BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
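// Note (added for illustration, not part of the source): BlockwiseWelford::Run
// is assumed here to combine the per-thread partials across the block with the
// standard count-aware Welford merge:
//   count_ab = count_a + count_b
//   delta    = mean_b - mean_a
//   mean_ab  = mean_a + delta * count_b / count_ab
//   m2_ab    = m2_a + m2_b + delta * delta * count_a * count_b / count_ab
// Passing count_thread_buf(j), rather than the single scalar count used before
// this change, lets every per-element partial carry its own count through the merge.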
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
......@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(PostShuffleThreadSliceSize_M % MeanVarTransferScalarPerVector == 0);
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
MeanDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_grid_desc_mblock_mperblock_nblock),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
MeanVarTransferScalarPerVector,
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_grid_desc_mblock_mperblock_nblock,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
......@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType,
VarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(var_grid_desc_mblock_mperblock_nblock),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
MeanVarTransferScalarPerVector,
1,
InMemoryDataOperationEnum::Set,
1,
false>{var_grid_desc_mblock_mperblock_nblock,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
mean_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
var_grid_desc_mblock_mperblock_nblock,
mean_var_count_grid_desc_mblock_mperblock_nblock,
var_grid_buf);
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
});
} // shuffle C + Ds + welford + write out
......