Commit da047ec1 authored by ltqin
Browse files

Change BlockwiseSoftmax_V1::Run to take its LDS (shared-memory) reduction work buffer as an input argument

parent 23262ab6
...@@ -27,13 +27,13 @@ struct BlockwiseSoftmax_V1 ...@@ -27,13 +27,13 @@ struct BlockwiseSoftmax_V1
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
constexpr static auto c_thread_desc = make_naive_tensor_descriptor_packed( constexpr static auto c_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
template <typename CThreadBuffer> template <typename CThreadBuffer>
__host__ __device__ static void Run(CThreadBuffer& c_thread_buf) __host__ __device__ static void Run(CThreadBuffer& c_thread_buf, void* __restrict__ p_shared)
{ {
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value, // printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value); // c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value);
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; auto p_reduce_work_buffer = static_cast<AccDataType*>(p_shared);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>(); max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
......
...@@ -473,6 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -473,6 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize, using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
FloatAcc, FloatAcc,
MPerXDL, MPerXDL,
...@@ -482,7 +483,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -482,7 +483,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
NXdlPerWave, NXdlPerWave,
1, 1,
1>; 1>;
BlockwiseSoftmax::Run(c_thread_buf); BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer);
} }
// output: register to global memory // output: register to global memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment