Commit 2717e60d authored by ltqin
Browse files

code regular

parent d8154515
...@@ -91,12 +91,10 @@ struct BlockwiseSoftmax_V1 ...@@ -91,12 +91,10 @@ struct BlockwiseSoftmax_V1
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
template <typename CThreadBuffer> template <typename CThreadBuffer>
__host__ __device__ static void __host__ __device__ static void
Run(CThreadBuffer& in_thread_buf, void* __restrict__ p_shared, void* __restrict__ p_softmax) Run(CThreadBuffer& in_thread_buf, void* __restrict__ p_reduce, void* __restrict__ p_softmax)
{ {
// printf("in_thread_desc: {%d, %d, %d}", in_thread_desc.GetLength(I0).value,
// in_thread_desc.GetLength(I1).value, in_thread_desc.GetLength(I2).value);
auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_shared), BlockSize); static_cast<AccDataType*>(p_reduce), BlockSize);
auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AccDataType*>(p_softmax), MPerBlock * 2); static_cast<AccDataType*>(p_softmax), MPerBlock * 2);
...@@ -125,7 +123,7 @@ struct BlockwiseSoftmax_V1 ...@@ -125,7 +123,7 @@ struct BlockwiseSoftmax_V1
// block reduce for max // block reduce for max
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0)); BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save max value // save max value to lds
if(0 == thread_k_cluster_id) if(0 == thread_k_cluster_id)
{ {
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset( softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
...@@ -160,7 +158,7 @@ struct BlockwiseSoftmax_V1 ...@@ -160,7 +158,7 @@ struct BlockwiseSoftmax_V1
BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0)); BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
block_sync_lds(); block_sync_lds();
// save sum // save sum to lds
if(0 == thread_k_cluster_id) if(0 == thread_k_cluster_id)
{ {
softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset( softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment