"docs/vscode:/vscode.git/clone" did not exist on "e464265cf842c2fa2af2e0d37a0d61391faa4745"
Commit 3b3cfae5 authored by Chao Liu

add blockwise copy that doesn't have thread buffer as member, to avoid alloca and therefore scratch mem
parent 54138dc8
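
For context, a hedged before/after sketch of the pattern the commit message describes: a blockwise copy that owns its thread buffer as a member array forces an alloca of the whole object and risks spilling to scratch memory, while a copy that stages through a caller-owned buffer passed to RunRead/RunWrite does not. The struct and field names below are illustrative placeholders, not the repo's classes.

#include <hip/hip_runtime.h>

// before: the copy object owns the staging buffer; the member array forces an
// alloca for the whole object, and an un-promoted alloca ends up in scratch memory
struct BlockwiseCopyOwningBuffer
{
    float thread_buffer_[16];

    __device__ void Run(const float* p_src, float* p_dst)
    {
        for(int i = 0; i < 16; ++i)
            thread_buffer_[i] = p_src[i];
        for(int i = 0; i < 16; ++i)
            p_dst[i] = thread_buffer_[i];
    }
};

// after: the caller owns the buffer and hands it to RunRead/RunWrite, which is the
// shape of the BlockwiseDynamicTensorSliceTransfer_v2r3 interface added in this commit
struct BlockwiseCopyTakingBuffer
{
    __device__ void RunRead(const float* p_src, float* p_thread_buffer)
    {
        for(int i = 0; i < 16; ++i)
            p_thread_buffer[i] = p_src[i];
    }

    __device__ void RunWrite(float* p_dst, const float* p_thread_buffer)
    {
        for(int i = 0; i < 16; ++i)
            p_dst[i] = p_thread_buffer[i];
    }
};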
@@ -174,37 +174,37 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
    // GEMM
    using gridwise_gemm =
-       GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
+       GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
                                        Float,
                                        AccFloat,
                                        InMemoryDataOperation::Set,
                                        GemmMPerBlock,
                                        GemmNPerBlock,
                                        GemmKPerBlock,
                                        GemmMPerThread,
                                        GemmNPerThread,
                                        GemmKPerThread,
                                        GemmMLevel0Cluster,
                                        GemmNLevel0Cluster,
                                        GemmMLevel1Cluster,
                                        GemmNLevel1Cluster,
                                        GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
                                        GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
                                        Sequence<1, 0>,
                                        Sequence<1, 0>,
                                        0,
                                        GemmABlockTransferSrcScalarPerVector_GemmK,
                                        GemmABlockTransferDstScalarPerVector_GemmM,
                                        GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
                                        GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
                                        Sequence<0, 1>,
                                        Sequence<0, 1>,
                                        1,
                                        GemmBBlockTransferSrcScalarPerVector_GemmN,
                                        GemmBBlockTransferDstScalarPerVector_GemmN,
                                        Sequence<2, 3, 0, 1>,
                                        3,
                                        GemmCThreadTransferDstScalarPerVector_GemmN1>;

    const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
...
@@ -121,7 +121,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v1r1
    ThreadwiseTransfer threadwise_transfer_;
};

-// this version has scratch memory issue, due to:
+// this version is very likely to have scratch memory issue, due to:
// 1. ThreadwiseDynamicTensorSliceTransfer_v1r1 keeps reference to tensor descriptor
// 2. threadwise_dynamic_tensor_slice_transfer_v1r1 constructs new tensor coordinate
template <index_t BlockSize,
@@ -287,7 +287,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r1
    BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
};

-// this version does not have scratch memory issue, due to:
+// this version does following things to avoid scratch memory issue
// 1. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
// 2. threadwise_dynamic_tensor_slice_transfer_v1r2 does not construct new tensor coordinate
template <index_t BlockSize,
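
The two comment changes above name the cause and the avoidance: v1r1 stores a reference to a runtime tensor descriptor (and rebuilds tensor coordinates inside the transfer), while v1r2 takes the descriptor as a call argument and keeps nothing addressable as a member. A minimal hedged sketch of the two patterns, using made-up types (Desc, TransferKeepingReference, TransferTakingDescPerCall) rather than the library's classes:

#include <hip/hip_runtime.h>

// toy stand-in for a dynamic (runtime-valued) tensor descriptor
struct Desc
{
    int strides_[2];
};

// v1r1-style: a reference member forces the referenced descriptor to have a
// materialized address, which can keep it (and the enclosing object) out of
// registers and in scratch/private memory
struct TransferKeepingReference
{
    __device__ TransferKeepingReference(const Desc& desc) : desc_(desc) {}

    __device__ float Load(const float* p_src, int i0, int i1) const
    {
        return p_src[i0 * desc_.strides_[0] + i1 * desc_.strides_[1]];
    }

    const Desc& desc_;
};

// v1r2-style: no descriptor is stored; the caller passes it on every call, so the
// compiler is free to keep its values in registers or fold them away
struct TransferTakingDescPerCall
{
    __device__ float Load(const Desc& desc, const float* p_src, int i0, int i1) const
    {
        return p_src[i0 * desc.strides_[0] + i1 * desc.strides_[1]];
    }
};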
@@ -462,5 +462,169 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r2
    BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
};
// this version does the following things to avoid the scratch memory issue:
// 1. BlockwiseDynamicTensorSliceTransfer_v2r3 doesn't allocate thread buffer (array) as member
// 2. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
// 3. threadwise_dynamic_tensor_slice_transfer_v1r2 does not construct new tensor coordinate
template <index_t BlockSize,
typename BlockSrcData,
typename BlockDstData,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride,
index_t DstDataStride>
struct BlockwiseDynamicTensorSliceTransfer_v2r3
{
static constexpr index_t nDim =
remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r3(
const BlockSrcDesc& block_src_desc,
const Index& src_block_slice_origin,
const BlockDstDesc& block_dst_desc,
const Index& dst_block_slice_origin)
: threadwise_read_(block_src_desc,
make_zero_multi_index<nDim>(),
thread_buffer_desc_,
make_zero_multi_index<nDim>()),
threadwise_write_(thread_buffer_desc_,
make_zero_multi_index<nDim>(),
block_dst_desc,
make_zero_multi_index<nDim>())
{
static_assert(
nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_read_.SetSrcSliceOrigin(block_src_desc,
src_block_slice_origin + thread_data_id_begin);
threadwise_read_.SetDstSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetSrcSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetDstSliceOrigin(block_dst_desc,
dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ static constexpr auto CalculateThreadDataBegin()
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
return thread_cluster_id * ThreadSliceLengths{};
}
__device__ void RunRead(const BlockSrcDesc& block_src_desc,
const BlockSrcData* p_block_src,
BlockSrcData* p_thread_buffer)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer);
}
}
__device__ void RunWrite(const BlockDstDesc& block_dst_desc,
BlockDstData* p_block_dst,
BlockDstData* p_thread_buffer)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.Run(
thread_buffer_desc_, p_thread_buffer, block_dst_desc, p_block_dst);
}
}
__device__ void MoveSrcSliceWindow(const BlockSrcDesc& block_src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.MoveSrcSliceWindow(block_src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const BlockDstDesc& block_dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.MoveDstSliceWindow(block_dst_desc, step);
}
}
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto thread_buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc,
decltype(thread_buffer_desc_),
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>;
using ThreadwiseWrite = ThreadwiseDynamicTensorSliceTransfer_v1r2<decltype(thread_buffer_desc_),
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorWriteDim,
1,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>;
ThreadwiseRead threadwise_read_;
ThreadwiseWrite threadwise_write_;
};
} // namespace ck
#endif
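
A hedged usage sketch of the new v2r3 interface, not code from this commit: the kernel body, rather than the copy object, owns the thread buffer and passes it into RunRead/RunWrite each iteration. BlockwiseCopy, SrcDesc, DstDesc, Float, Step, and thread_buffer_size are placeholders for whatever the gridwise GEMM instantiates.

#include <hip/hip_runtime.h>

template <int thread_buffer_size,
          typename BlockwiseCopy,
          typename SrcDesc,
          typename DstDesc,
          typename Float,
          typename Step>
__device__ void run_blockwise_copy_once(BlockwiseCopy& blockwise_copy,
                                        const SrcDesc& src_desc,
                                        const Float* p_src_global,
                                        const DstDesc& dst_desc,
                                        Float* p_dst_block,
                                        const Step& src_step)
{
    // the staging buffer is a plain local array owned by the caller; the copy
    // object holds no member array, so nothing forces an alloca of the object
    Float p_thread_buffer[thread_buffer_size];

    blockwise_copy.RunRead(src_desc, p_src_global, p_thread_buffer);  // global -> registers
    blockwise_copy.RunWrite(dst_desc, p_dst_block, p_thread_buffer);  // registers -> LDS
    blockwise_copy.MoveSrcSliceWindow(src_desc, src_step);            // advance window for next loop iteration
}

Compared with v2r2, which keeps p_thread_buffer_ as a member array, this keeps the transfer object small and gives the compiler a better chance of promoting the staging buffer to registers.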