Commit 003ec407 authored by rocking's avatar rocking
Browse files

Add welford count

parent b7f500f0
...@@ -47,8 +47,7 @@ template <typename ABDataType, ...@@ -47,8 +47,7 @@ template <typename ABDataType,
typename BGridDesc_N_K, typename BGridDesc_N_K,
typename DsGridDesc_M_N, typename DsGridDesc_M_N,
typename EGridDesc_M_N, typename EGridDesc_M_N,
typename MeanGridDesc_M_N, typename MeanVarCountGridDesc_M_N,
typename VarGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -80,7 +79,6 @@ template <typename ABDataType, ...@@ -80,7 +79,6 @@ template <typename ABDataType,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename PostShuffleThreadClusterSize_M_N, typename PostShuffleThreadClusterSize_M_N,
index_t PostShuffleScalarPerVector, index_t PostShuffleScalarPerVector,
index_t MeanVarTransferScalarPerVector,
LoopScheduler LoopSched> LoopScheduler LoopSched>
struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{ {
...@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock // TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template <typename GridDescriptor_M_N> template <typename GridDescriptor_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n) MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
{ {
const auto M = grid_desc_m_n.GetLength(I0); const auto M = grid_desc_m_n.GetLength(I0);
const auto NBlock = grid_desc_m_n.GetLength(I1); const auto NBlock = grid_desc_m_n.GetLength(I1);
...@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k, const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n, const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const MeanGridDesc_M_N& mean_grid_desc_m_n, const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
const VarGridDesc_M_N& var_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
...@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc // check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
M == mean_grid_desc_m_n.GetLength(I0) && M == var_grid_desc_m_n.GetLength(I0) && M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
N / NPerBlock == mean_grid_desc_m_n.GetLength(I1) && N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
N / NPerBlock == var_grid_desc_m_n.GetLength(I1)))
{ {
return false; return false;
} }
...@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>; MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using MeanGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype( using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(MeanGridDesc_M_N{}))>; MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
using VarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(VarGridDesc_M_N{}))>;
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
...@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename Block2ETileMap> typename Block2ETileMap>
__device__ static void __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
Run(const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid,
const ABDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid,
DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid,
EDataType* __restrict__ p_e_grid, MeanDataType* __restrict__ p_welford_mean_grid,
MeanDataType* __restrict__ p_mean_grid, VarDataType* __restrict__ p_welford_var_grid,
VarDataType* __restrict__ p_var_grid, int32_t* __restrict__ p_welford_count,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op, const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock,
const MeanGridDescriptor_MBlock_MPerBlock_NBlock& mean_grid_desc_mblock_mperblock_nblock, const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
const VarGridDescriptor_MBlock_MPerBlock_NBlock& var_grid_desc_mblock_mperblock_nblock, mean_var_count_grid_desc_mblock_mperblock_nblock,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto mean_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_grid, mean_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); p_welford_mean_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto var_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_var_grid, var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); p_welford_var_grid,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count,
mean_var_count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>( decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize())); thread_welford_dst_desc_m.GetElementSpaceSize()));
using welford_count_vgpr_type =
decltype(make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize()));
Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords; Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs; Array<mean_var_vgpr_type, num_shuffleM> mean_thread_bufs;
Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs; Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
static_for<0, num_shuffleM, 1>{}([&](auto i) { static_for<0, num_shuffleM, 1>{}([&](auto i) {
// TODO - padding // TODO - padding
...@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>( var_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
thread_welford_dst_desc_m.GetElementSpaceSize()); thread_welford_dst_desc_m.GetElementSpaceSize());
welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
thread_welford_dst_desc_m.GetElementSpaceSize());
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f); mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f); var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
welford_count_thread_bufs(i)(j) = 0;
}); });
}); });
...@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// Blockwise welford and write out // Blockwise welford and write out
static_for<0, num_shuffleM, 1>{}([&](auto i) { static_for<0, num_shuffleM, 1>{}([&](auto i) {
auto& mean_thread_buf = mean_thread_bufs(i); auto& mean_thread_buf = mean_thread_bufs(i);
auto& var_thread_buf = var_thread_bufs(i); auto& var_thread_buf = var_thread_bufs(i);
int count = threadwise_welfords(i).cur_count_; auto& count_thread_buf = welford_count_thread_bufs(i);
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds(); block_sync_lds();
BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count); BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
}); });
constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
...@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr int shuffleMPerBlock = constexpr int shuffleMPerBlock =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
static_assert(PostShuffleThreadSliceSize_M % MeanVarTransferScalarPerVector == 0);
auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto mean_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, AccDataType,
MeanDataType, MeanDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(mean_grid_desc_mblock_mperblock_nblock), decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
1, 1,
MeanVarTransferScalarPerVector, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{mean_grid_desc_mblock_mperblock_nblock, false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
...@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle ...@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType, AccDataType,
VarDataType, VarDataType,
decltype(thread_welford_desc_I_m_I), decltype(thread_welford_desc_I_m_I),
decltype(var_grid_desc_mblock_mperblock_nblock), decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}};
auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
int32_t,
int32_t,
decltype(thread_welford_desc_I_m_I),
decltype(mean_var_count_grid_desc_mblock_mperblock_nblock),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<1, PostShuffleThreadSliceSize_M, 1>, Sequence<1, PostShuffleThreadSliceSize_M, 1>,
Sequence<0, 1, 2>, Sequence<0, 1, 2>,
1, 1,
MeanVarTransferScalarPerVector, 1,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>{var_grid_desc_mblock_mperblock_nblock, false>{mean_var_count_grid_desc_mblock_mperblock_nblock,
make_multi_index(block_work_idx[I0], // mblock make_multi_index(block_work_idx[I0], // mblock
shuffleMPerBlock * i + shuffleMPerBlock * i +
post_shuffle_thread_data_idx_begin[I0], // mperblock post_shuffle_thread_data_idx_begin[I0], // mperblock
block_work_idx[I1]), // nblock block_work_idx[I1]), // nblock
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
mean_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, mean_thread_copy_vgpr_to_global.Run(
make_tuple(I0, I0, I0), thread_welford_desc_I_m_I,
mean_thread_buf, make_tuple(I0, I0, I0),
mean_grid_desc_mblock_mperblock_nblock, mean_thread_buf,
mean_grid_buf); mean_var_count_grid_desc_mblock_mperblock_nblock,
mean_grid_buf);
var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I, var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0), make_tuple(I0, I0, I0),
var_thread_buf, var_thread_buf,
var_grid_desc_mblock_mperblock_nblock, mean_var_count_grid_desc_mblock_mperblock_nblock,
var_grid_buf); var_grid_buf);
count_thread_copy_vgpr_to_global.Run(
thread_welford_desc_I_m_I,
make_tuple(I0, I0, I0),
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
}); });
} // shuffle C + Ds + welford + write out } // shuffle C + Ds + welford + write out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment