Commit ef0c2e64 authored by Jing Zhang

move thread_buff into blockcopy

parent 7cf350d6
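The change makes the thread register buffer a member of `BlockwiseGenericTensorSliceCopy_v5`, so callers no longer allocate a `thread_buff` and thread it through `Run`, `RunLoadThreadBuffer`, and `RunStoreThreadBuffer`. A minimal host-side sketch of the ownership change (not the real ck classes; `RegBuffer`, `BlockCopySketch`, and the plain loops are illustrative stand-ins for the VGPR-backed, `__device__` implementation):

```cpp
// Sketch only, assuming simplified types: the staging buffer moves from a
// call argument into the copy object, mirroring what this commit does with
// thread_buff.
#include <cstddef>

template <typename T, std::size_t N>
struct RegBuffer {          // stand-in for the register-backed thread buffer
    T data[N];
};

template <typename T, std::size_t N>
struct BlockCopySketch {
    RegBuffer<T, N> thread_buff; // previously owned and passed in by the caller

    void RunLoadThreadBuffer(const T* p_block_src) {
        for (std::size_t i = 0; i < N; ++i)
            thread_buff.data[i] = p_block_src[i]; // stage source into registers
    }

    void RunStoreThreadBuffer(T* p_block_dst) const {
        for (std::size_t i = 0; i < N; ++i)
            p_block_dst[i] = thread_buff.data[i]; // write staged data out
    }

    void Run(const T* p_block_src, T* p_block_dst) {
        RunLoadThreadBuffer(p_block_src);   // no thread_buff argument anymore
        RunStoreThreadBuffer(p_block_dst);
    }
};
```

With the buffer owned by the copy object, the gridwise GEMM in the second file below can drop its local `thread_buff_a` / `thread_buff_b` declarations and call the copy methods with only source and destination pointers.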
@@ -78,9 +78,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
         return ThreadBufferDesc::GetElementSpace();
     }
 
-    template <typename BlockSrcData, typename ThreadBuffData>
-    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
-                                        ThreadBuffData& thread_buff)
+    template <typename BlockSrcData>
+    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src)
     {
         if(BlockSize == mThreadClusterDesc.GetElementSize() or
            get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
@@ -89,8 +88,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
         }
     }
 
-    template <typename ThreadBuffData, typename BlockDstData>
-    __device__ void RunStoreThreadBuffer(ThreadBuffData thread_buff, BlockDstData* p_block_dst)
+    template <typename BlockDstData>
+    __device__ void RunStoreThreadBuffer(BlockDstData* p_block_dst)
     {
         if(BlockSize == mThreadClusterDesc.GetElementSize() or
            get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
@@ -99,9 +98,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
         }
     }
 
-    template <typename BlockSrcData, typename BlockDstData, typename ThreadBuffData>
-    __device__ void
-    Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst, ThreadBuffData& thread_buff)
+    template <typename BlockSrcData, typename BlockDstData>
+    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
     {
         static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
                       "wrong! This function use vgpr as its thread "
@@ -112,8 +110,8 @@ struct BlockwiseGenericTensorSliceCopy_v5
         if(BlockSize == mThreadClusterDesc.GetElementSize() or
            get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
         {
-            RunLoadThreadBuffer(p_block_src, thread_buff);
-            RunStoreThreadBuffer(thread_buff, p_block_dst);
+            RunLoadThreadBuffer(p_block_src);
+            RunStoreThreadBuffer(p_block_dst);
         }
     }
 
@@ -163,6 +161,9 @@ struct BlockwiseGenericTensorSliceCopy_v5
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
 
     ThreadwiseCopy mThreadwiseCopy;
+
+    using ThreadBufferType = decltype(GetRegBuffer<float, GetThreadBufferSize()>());
+    ThreadBufferType thread_buff;
 };
 
 } // namespace ck
@@ -496,18 +496,10 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
         constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
         auto c_thread_vec = GetRegBuffer<AccFloat, c_thread_size>();
 
-        using ThreadBufferTypeA =
-            decltype(GetRegBuffer<ABFloat, a_blockwise_copy.GetThreadBufferSize()>());
-        using ThreadBufferTypeB =
-            decltype(GetRegBuffer<ABFloat, b_blockwise_copy.GetThreadBufferSize()>());
-
-        ThreadBufferTypeA thread_buff_a;
-        ThreadBufferTypeB thread_buff_b;
-
         // preload data into LDS
         {
-            a_blockwise_copy.Run(p_a_global, p_a_block, thread_buff_a);
-            b_blockwise_copy.Run(p_b_global, p_b_block, thread_buff_b);
+            a_blockwise_copy.Run(p_a_global, p_a_block);
+            b_blockwise_copy.Run(p_b_global, p_b_block);
         }
 
         constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
@@ -521,8 +513,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
             a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
 
-            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, thread_buff_a);
-            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, thread_buff_b);
+            a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
+            b_blockwise_copy.RunLoadThreadBuffer(p_b_global);
 
             block_sync_lds();
 
@@ -539,8 +531,8 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
             block_sync_lds();
 
             // store next data to LDS
-            a_blockwise_copy.RunStoreThreadBuffer(thread_buff_a, p_a_block);
-            b_blockwise_copy.RunStoreThreadBuffer(thread_buff_b, p_b_block);
+            a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
+            b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
         }
 
         // tail