Commit bf44991f authored by Anthony Chang
Browse files

use LDS mem pool for reduction workspace

parent 3db406f0
...@@ -217,12 +217,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -217,12 +217,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size = // Align 16 bytes (maximum LDS read/write width)
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); constexpr auto c_block_size_aligned = math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * sizeof(FloatCShuffle), 16) / sizeof(FloatCShuffle);
// LDS allocation for reduction workspace
constexpr index_t c_lds_workspace_size = BlockSize;
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB), sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle)); c_block_size_aligned * sizeof(FloatCShuffle) + c_lds_workspace_size * sizeof(FloatReduceAcc));
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -734,11 +738,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -734,11 +738,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC0>( auto c0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatC0>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// TODO ANT: incorporate in singly defined p_shared. calculate proper total size in // Align 16 bytes (maximum LDS read/write width)
// GetSharedMemoryNumberOfByte() and shift pointer as approriate constexpr auto c_block_size_aligned = math::integer_least_multiple(
__shared__ FloatReduceAcc p_d_reduce_work_buffer[BlockSize]; c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * sizeof(FloatCShuffle), 16) / sizeof(FloatCShuffle);
auto d_reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_d_reduce_work_buffer, BlockSize); auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) + c_block_size_aligned), BlockSize);
// Sum thread workspace // Sum thread workspace
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment