Commit ef0c2e64 authored by Jing Zhang
Browse files

move thread_buff into blockcopy

parent 7cf350d6
...@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
return ThreadBufferDesc::GetElementSpace(); return ThreadBufferDesc::GetElementSpace();
} }
template <typename BlockSrcData, typename ThreadBuffData> template <typename BlockSrcData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src, __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src)
ThreadBuffData& thread_buff)
{ {
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
...@@ -89,8 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -89,8 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
} }
} }
template <typename ThreadBuffData, typename BlockDstData> template <typename BlockDstData>
__device__ void RunStoreThreadBuffer(ThreadBuffData thread_buff, BlockDstData* p_block_dst) __device__ void RunStoreThreadBuffer(BlockDstData* p_block_dst)
{ {
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
...@@ -99,9 +98,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -99,9 +98,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
} }
} }
template <typename BlockSrcData, typename BlockDstData, typename ThreadBuffData> template <typename BlockSrcData, typename BlockDstData>
__device__ void __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst, ThreadBuffData& thread_buff)
{ {
static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr, static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
"wrong! This function use vgpr as its thread " "wrong! This function use vgpr as its thread "
...@@ -112,8 +110,8 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -112,8 +110,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
if(BlockSize == mThreadClusterDesc.GetElementSize() or if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{ {
RunLoadThreadBuffer(p_block_src, thread_buff); RunLoadThreadBuffer(p_block_src);
RunStoreThreadBuffer(thread_buff, p_block_dst); RunStoreThreadBuffer(p_block_dst);
} }
} }
...@@ -163,6 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v5 ...@@ -163,6 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v5
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
ThreadwiseCopy mThreadwiseCopy; ThreadwiseCopy mThreadwiseCopy;
using ThreadBufferType = decltype(GetRegBuffer<float, GetThreadBufferSize()>());
ThreadBufferType thread_buff;
}; };
} // namespace ck } // namespace ck
......
...@@ -496,18 +496,10 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -496,18 +496,10 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize; constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>(); auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
using ThreadBufferTypeA =
decltype(GetRegBuffer<ABFloat, a_blockwise_copy.GetThreadBufferSize()>());
using ThreadBufferTypeB =
decltype(GetRegBuffer<ABFloat, b_blockwise_copy.GetThreadBufferSize()>());
ThreadBufferTypeA thread_buff_a;
ThreadBufferTypeB thread_buff_b;
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.Run(p_a_global, p_a_block, thread_buff_a); a_blockwise_copy.Run(p_a_global, p_a_block);
b_blockwise_copy.Run(p_b_global, p_b_block, thread_buff_b); b_blockwise_copy.Run(p_b_global, p_b_block);
} }
constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{}; constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
...@@ -521,8 +513,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -521,8 +513,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True); a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True); b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, thread_buff_a); a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, thread_buff_b); b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
block_sync_lds(); block_sync_lds();
...@@ -539,8 +531,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2 ...@@ -539,8 +531,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
block_sync_lds(); block_sync_lds();
// store next data to LDS // store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(thread_buff_a, p_a_block); a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
b_blockwise_copy.RunStoreThreadBuffer(thread_buff_b, p_b_block); b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
} }
// tail // tail
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment