Commit d4368d77 authored by ltqin
Browse files

max and sum save to register

parent 2717e60d
...@@ -43,9 +43,6 @@ struct BlockwiseSoftmax_V1 ...@@ -43,9 +43,6 @@ struct BlockwiseSoftmax_V1
} }
}; };
static constexpr auto softmax_buf_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerBlock>{}, Number<2>{}));
constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed( constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
...@@ -91,19 +88,10 @@ struct BlockwiseSoftmax_V1 ...@@ -91,19 +88,10 @@ struct BlockwiseSoftmax_V1
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
template <typename CThreadBuffer> template <typename CThreadBuffer>
__host__ __device__ static void __host__ __device__ static void
Run(CThreadBuffer& in_thread_buf, void* __restrict__ p_reduce, void* __restrict__ p_softmax) Run(CThreadBuffer& in_thread_buf, float& f_sum, float& f_max, void* __restrict__ p_reduce)
{ {
auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_reduce), BlockSize); static_cast<AccDataType*>(p_reduce), BlockSize);
auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_softmax), MPerBlock * 2);
// thread id map to thread layout
const index_t thread_local_id = get_thread_local_1d_id();
const auto thread_cluster_idx =
BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
// //
// find max value // find max value
// //
...@@ -123,12 +111,8 @@ struct BlockwiseSoftmax_V1 ...@@ -123,12 +111,8 @@ struct BlockwiseSoftmax_V1
// block reduce for max // block reduce for max
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0)); BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save max value to lds // save max
if(0 == thread_k_cluster_id) f_max = max_value_buf(I0);
{
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0);
}
// //
// softmax // softmax
...@@ -158,12 +142,8 @@ struct BlockwiseSoftmax_V1 ...@@ -158,12 +142,8 @@ struct BlockwiseSoftmax_V1
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0)); BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save sum to lds // save sum
if(0 == thread_k_cluster_id) f_sum = accu_value_buf(I0);
{
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0);
}
} }
}; // namespace ck }; // namespace ck
......
...@@ -474,7 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -474,7 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
num_k_block_main_loop); num_k_block_main_loop);
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
__shared__ AccDataType p_lex[MPerBlock * 2]; float f_sum, f_max;
using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize, using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
FloatAcc, FloatAcc,
...@@ -484,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -484,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
blockwise_gemm.GetRegSizePerXdlops(), blockwise_gemm.GetRegSizePerXdlops(),
MXdlPerWave, MXdlPerWave,
NXdlPerWave>; NXdlPerWave>;
BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer, p_lex); BlockwiseSoftmax::Run(c_thread_buf, f_sum, f_max, p_reduce_work_buffer);
} }
// output: register to global memory // output: register to global memory
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment