Commit bf44991f authored by Anthony Chang's avatar Anthony Chang
Browse files

use LDS mem pool for reduction workspace

parent 3db406f0
......@@ -217,12 +217,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
// Align 16 bytes (maximum LDS read/write width)
constexpr auto c_block_size_aligned = math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * sizeof(FloatCShuffle), 16) / sizeof(FloatCShuffle);
// LDS allocation for reduction workspace
constexpr index_t c_lds_workspace_size = BlockSize;
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
c_block_size_aligned * sizeof(FloatCShuffle) + c_lds_workspace_size * sizeof(FloatReduceAcc));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
......@@ -734,11 +738,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC0>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// TODO ANT: incorporate in singly defined p_shared. calculate proper total size in
// GetSharedMemoryNumberOfByte() and shift pointer as approriate
__shared__ FloatReduceAcc p_d_reduce_work_buffer[BlockSize];
auto d_reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_d_reduce_work_buffer, BlockSize);
// Align 16 bytes (maximum LDS read/write width)
constexpr auto c_block_size_aligned = math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * sizeof(FloatCShuffle), 16) / sizeof(FloatCShuffle);
auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) + c_block_size_aligned), BlockSize);
// Sum thread workspace
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment