Commit c85e0652 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent e64a79c5
......@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
return ThreadBufferDesc::GetElementSpace();
}
template <typename BlockSrcData, typename ThreadBufferData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBufferData* p_thread_buffer)
template <typename BlockSrcData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src)
{
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
......@@ -89,9 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
}
}
template <typename ThreadBufferData, typename BlockDstData>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
BlockDstData* p_block_dst)
template <typename BlockDstData>
__device__ void RunStoreThreadBuffer(BlockDstData* p_block_dst)
{
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
......@@ -109,13 +107,11 @@ struct BlockwiseGenericTensorSliceCopy_v5
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
"Behavior may be different");
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
RunLoadThreadBuffer(p_block_src);
RunStoreThreadBuffer(p_block_dst);
}
}
......
......@@ -127,7 +127,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
BlockwiseGenericTensorSliceCopy_v5<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
......@@ -253,24 +253,21 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
a_blockwise_copy.RunStoreThreadBuffer(p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_block_next);
}
}
......@@ -281,7 +278,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
......@@ -289,17 +285,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
a_blockwise_copy.RunStoreThreadBuffer(p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_block_double + b_block_space);
__syncthreads();
......
......@@ -177,7 +177,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment