Commit 003ec407 authored by rocking
Browse files

Add welford count

parent b7f500f0
......@@ -47,8 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N,
typename MeanGridDesc_M_N,
typename VarGridDesc_M_N,
typename MeanVarCountGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -80,7 +79,6 @@ template <typename ABDataType,
index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector,
index_t MeanVarTransferScalarPerVector,
LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
......@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{});
}
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{
const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1);
......@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const MeanGridDesc_M_N& mean_grid_desc_m_n,
const VarGridDesc_M_N& var_grid_desc_m_n,
const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
......@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
M == mean_grid_desc_m_n.GetLength(I0) && M == var_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_grid_desc_m_n.GetLength(I1) &&
N / NPerBlock == var_grid_desc_m_n.GetLength(I1)))
M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
{
return false;
}
......@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(MeanGridDesc_M_N{}))>;
using VarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(VarGridDesc_M_N{}))>;
using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap>
__device__ static void
Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_mean_grid,
VarDataType* __restrict__ p_var_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock& mean_grid_desc_mblock_mperblock_nblock,
const VarGridDescriptor_MBlock_MPerBlock_NBlock& var_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map)
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
MeanDataType* __restrict__ p_welford_mean_grid,
VarDataType* __restrict__ p_welford_var_grid,
int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_grid, mean_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_mean_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_var_grid, var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
p_welford_var_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
......@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding
......@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
});
});
......@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
int count = threadwise_welfords(i).cur_count_;
auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i);
auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count);
BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
......@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(PostShuffleThreadSliceSize_M % MeanVarTransferScalarPerVector == 0);
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
MeanDataType,
decltype(thread_welford_desc_I_m_I),
decltype(mean_grid_desc_mblock_mperblock_nblock),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
MeanVarTransferScalarPerVector,
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_grid_desc_mblock_mperblock_nblock,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
......@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType,
VarDataType,
decltype(thread_welford_desc_I_m_I),
decltype(var_grid_desc_mblock_mperblock_nblock),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
MeanVarTransferScalarPerVector,
1,
InMemoryDataOperationEnum::Set,
1,
false>{var_grid_desc_mblock_mperblock_nblock,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
mean_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
mean_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
var_thread_buf,
var_grid_desc_mblock_mperblock_nblock,
mean_var_count_grid_desc_mblock_mperblock_nblock,
var_grid_buf);
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
});
} // shuffle C + Ds + welford + write out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment