Commit c85e0652 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent e64a79c5
...@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
return ThreadBufferDesc::GetElementSpace(); return ThreadBufferDesc::GetElementSpace();
} }
template <typename BlockSrcData, typename ThreadBufferData> template <typename BlockSrcData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src, __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src)
ThreadBufferData* p_thread_buffer)
{ {
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
...@@ -89,9 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -89,9 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
} }
} }
template <typename ThreadBufferData, typename BlockDstData> template <typename BlockDstData>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer, __device__ void RunStoreThreadBuffer(BlockDstData* p_block_dst)
BlockDstData* p_block_dst)
{ {
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
...@@ -109,13 +107,11 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -109,13 +107,11 @@ struct BlockwiseGenericTensorSliceCopy_v5
"to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. " "to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
"Behavior may be different"); "Behavior may be different");
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{ {
RunLoadThreadBuffer(p_block_src, p_thread_buffer); RunLoadThreadBuffer(p_block_src);
RunStoreThreadBuffer(p_thread_buffer, p_block_dst); RunStoreThreadBuffer(p_block_dst);
} }
} }
......
...@@ -127,7 +127,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -127,7 +127,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize, BlockwiseGenericTensorSliceCopy_v5<BlockSize,
decltype(a_k_m_global_desc), decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc), decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()), decltype(a_k_m_block_desc.GetLengths()),
...@@ -253,24 +253,21 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -253,24 +253,21 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float* p_b_block_next = Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double; even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread); blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next); a_blockwise_copy.RunStoreThreadBuffer(p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next); b_blockwise_copy.RunStoreThreadBuffer(p_b_block_next);
} }
} }
...@@ -281,7 +278,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -281,7 +278,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
if(has_two_iteration_left) // if has 2 iteration left if(has_two_iteration_left) // if has 2 iteration left
{ {
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
...@@ -289,17 +285,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ...@@ -289,17 +285,15 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, a_blockwise_copy.RunStoreThreadBuffer(p_a_block_double + a_block_space);
p_a_block_double + a_block_space); b_blockwise_copy.RunStoreThreadBuffer(p_b_block_double + b_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads(); __syncthreads();
......
...@@ -177,7 +177,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -177,7 +177,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment