Commit 760a234f authored by Chao Liu

use StaticallyIndexedArray for buffer in threadwise copy, in order to get rid of alloca in IR

parent 70d06fa9
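The idea behind this change, sketched below with simplified stand-ins (this static_for and StaticallyIndexedArraySketch4 are illustrations, not the library's actual utilities): a per-thread buffer declared as a plain C array and indexed with runtime loop variables has to be addressable, which shows up as "alloca" in the LLVM IR and can be lowered to scratch memory, while a tuple-like buffer whose elements are only ever touched through compile-time indices can live entirely in registers.

// Sketch only: simplified stand-ins for the library's StaticallyIndexedArray
// and static_for, to illustrate why compile-time indexing needs no alloca.
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

// compile-time loop: calls f(integral_constant<size_t, I>) for I = 0..N-1
template <typename F, std::size_t... Is>
void static_for_impl(F f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
void static_for(F f)
{
    static_for_impl(f, std::make_index_sequence<N>{});
}

// thread-buffer stand-in: one tuple element per slot; every access uses a
// compile-time index, so no element ever needs a runtime-computed address
template <typename T>
using StaticallyIndexedArraySketch4 = std::tuple<T, T, T, T>;

int main()
{
    StaticallyIndexedArraySketch4<float> buffer{};

    static_for<4>([&](auto i) { std::get<decltype(i)::value>(buffer) = 1.0f; });

    float sum = 0.0f;
    static_for<4>([&](auto i) { sum += std::get<decltype(i)::value>(buffer); });

    std::printf("%f\n", sum); // 4.0
}

A plain float buffer[4] filled inside an ordinary runtime loop would instead force the compiler to give the array an address, which is the alloca this commit removes from the threadwise copy's buffer.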
......@@ -173,39 +173,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// GEMM
#if 1
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
true, // move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
......@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_out_global,
integral_constant<bool, false>{});
}
#else
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v2<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global);
#endif
}
};
......
......@@ -9,514 +9,53 @@
namespace ck {
// this version does not have the scratch memory issue, which is good, but I don't know why
template <index_t BlockSize,
typename BlockSrcData,
typename BlockDstData,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDstDimAccessOrder,
index_t SrcDstVectoReadDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride,
index_t DstDataStride>
struct BlockwiseDynamicTensorSliceTransfer_v1r1
{
static constexpr index_t nDim =
remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v1r1(
const BlockSrcDesc& block_src_desc,
const Index& src_block_slice_origin,
const BlockDstDesc& block_dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(block_src_desc,
make_zero_multi_index<nDim>(),
block_dst_desc,
make_zero_multi_index<nDim>())
{
static_assert(
nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(p_block_src, p_block_dst);
}
}
__device__ void MoveSrcSliceWindow(const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(step);
}
}
__device__ void MoveDstSliceWindow(const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = ThreadwiseDynamicTensorSliceTransfer_v1r1<BlockSrcDesc,
BlockDstDesc,
ThreadSliceLengths,
SrcDstDimAccessOrder,
SrcDstVectoReadDim,
SrcDataPerRead,
DstDataPerWrite,
SrcAddressSpace,
DstAddressSpace,
DstInMemOp,
SrcDataStride,
DstDataStride>;
ThreadwiseTransfer threadwise_transfer_;
};
// this version tends to have a scratch memory issue, due to:
// 1. ThreadwiseDynamicTensorSliceTransfer_v1r1 keeps reference to tensor descriptor
// 2. ThreadwiseDynamicTensorSliceTransfer_v1r1::Run() constructs new tensor coordinate
template <index_t BlockSize,
typename BlockSrcData,
typename BlockDstData,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride,
index_t DstDataStride>
struct BlockwiseDynamicTensorSliceTransfer_v2r1
{
static constexpr index_t nDim =
remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r1(
const BlockSrcDesc& block_src_desc,
const Index& src_block_slice_origin,
const BlockDstDesc& block_dst_desc,
const Index& dst_block_slice_origin)
: threadwise_read_(block_src_desc,
make_zero_multi_index<nDim>(),
thread_buffer_desc_,
make_zero_multi_index<nDim>()),
threadwise_write_(thread_buffer_desc_,
make_zero_multi_index<nDim>(),
block_dst_desc,
make_zero_multi_index<nDim>())
{
static_assert(
nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_read_.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
threadwise_read_.SetDstSliceOrigin(make_zero_multi_index<nDim>());
threadwise_write_.SetSrcSliceOrigin(make_zero_multi_index<nDim>());
threadwise_write_.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ void RunRead(const BlockSrcData* p_block_src)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(p_block_src, p_thread_buffer_);
}
}
__device__ void RunWrite(BlockDstData* p_block_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.Run(p_thread_buffer_, p_block_dst);
}
}
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(p_block_src, p_thread_buffer_);
// if there is type conversion, it's done during write
threadwise_write_.Run(p_thread_buffer_, p_block_dst);
}
}
__device__ void MoveSrcSliceWindow(const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.MoveSrcSliceWindow(step);
}
}
__device__ void MoveDstSliceWindow(const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.MoveDstSliceWindow(step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto thread_buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r1<BlockSrcDesc,
decltype(thread_buffer_desc_),
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>;
using ThreadwiseWrite = ThreadwiseDynamicTensorSliceTransfer_v1r1<decltype(thread_buffer_desc_),
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorWriteDim,
1,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>;
ThreadwiseRead threadwise_read_;
ThreadwiseWrite threadwise_write_;
static constexpr index_t thread_buffer_element_size_ =
thread_buffer_desc_.GetElementSpaceSize();
BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
};
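A minimal host-side sketch of the structural difference these comments describe (DummyDesc and DummyCoord are made-up types, not the library's API): the v1r1/v2r1 style stores a reference to the tensor descriptor and rebuilds a coordinate inside Run(), whereas the v3/v4 style below passes the descriptor into every call and keeps only coordinates as state, which is exactly the difference the comments tie to having or avoiding the scratch memory issue.

// Illustration only, with dummy types.
struct DummyDesc  { int stride; };
struct DummyCoord { int offset; };

// v1r1/v2r1 style: descriptor captured by reference, coordinate rebuilt per Run()
struct TransferKeepsReference
{
    const DummyDesc& desc_;
    DummyCoord origin_;

    void Run(const float* src, float* dst) const
    {
        DummyCoord coord = origin_; // fresh coordinate every call
        dst[coord.offset] = src[coord.offset * desc_.stride];
    }
};

// v3/v4 style: stateless with respect to the descriptor; only coordinates are members
struct TransferStateless
{
    DummyCoord origin_;

    void Run(const DummyDesc& desc, const float* src, float* dst) const
    {
        dst[origin_.offset] = src[origin_.offset * desc.stride];
    }
};

int main()
{
    float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float dst[8] = {};

    const DummyDesc desc{2};
    TransferKeepsReference a{desc, DummyCoord{1}};
    TransferStateless b{DummyCoord{1}};

    a.Run(src, dst);       // dst[1] = src[2]
    b.Run(desc, src, dst); // same result, but desc is an argument, not state
}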
// this version does the following things to avoid the scratch memory issue
// 1. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
// 2. ThreadwiseDynamicTensorSliceTransfer_v1r2::Run() does not construct new tensor coordinate
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename BlockSrcData,
typename BlockDstData,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride,
index_t DstDataStride>
struct BlockwiseDynamicTensorSliceTransfer_v2r2
{
static constexpr index_t nDim =
remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r2(
const BlockSrcDesc& block_src_desc,
const Index& src_block_slice_origin,
const BlockDstDesc& block_dst_desc,
const Index& dst_block_slice_origin)
: threadwise_read_(block_src_desc,
make_zero_multi_index<nDim>(),
thread_buffer_desc_,
make_zero_multi_index<nDim>()),
threadwise_write_(thread_buffer_desc_,
make_zero_multi_index<nDim>(),
block_dst_desc,
make_zero_multi_index<nDim>())
{
static_assert(
nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_read_.SetSrcSliceOrigin(block_src_desc,
src_block_slice_origin + thread_data_id_begin);
threadwise_read_.SetDstSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetSrcSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetDstSliceOrigin(block_dst_desc,
dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ void RunRead(const BlockSrcDesc& block_src_desc, const BlockSrcData* p_block_src)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(
block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer_);
}
}
__device__ void RunWrite(const BlockDstDesc& block_dst_desc, BlockDstData* p_block_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.Run(
thread_buffer_desc_, p_thread_buffer_, block_dst_desc, p_block_dst);
}
}
__device__ void Run(const BlockSrcDesc& block_src_desc,
const BlockSrcData* p_block_src,
const BlockDstDesc& block_dst_desc,
BlockDstData* p_block_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(
block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer_);
// if there is type conversion, it's done during write
threadwise_write_.Run(
thread_buffer_desc_, p_thread_buffer_, block_dst_desc, p_block_dst);
}
}
__device__ void MoveSrcSliceWindow(const BlockSrcDesc& block_src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.MoveSrcSliceWindow(block_src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const BlockDstDesc& block_dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.MoveDstSliceWindow(block_dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto thread_buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc,
decltype(thread_buffer_desc_),
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>;
using ThreadwiseWrite = ThreadwiseDynamicTensorSliceTransfer_v1r2<decltype(thread_buffer_desc_),
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorWriteDim,
1,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>;
ThreadwiseRead threadwise_read_;
ThreadwiseWrite threadwise_write_;
static constexpr index_t thread_buffer_element_size_ =
thread_buffer_desc_.GetElementSpaceSize();
BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
};
// this version does the following things to avoid the scratch memory issue
// 1. BlockwiseDynamicTensorSliceTransfer_v2r3 doesn't allocate thread buffer (array) as member
// 2. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v1r2::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename BlockSrcData,
typename BlockDstData,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride,
index_t DstDataStride,
index_t ThreadTransferMoveBackSrcCoord = true,
index_t ThreadTransferMoveBackDstCoord = true>
struct BlockwiseDynamicTensorSliceTransfer_v2r3
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
index_t ThreadTransferSrcResetCoordinateAfterRun,
index_t ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4
{
static constexpr index_t nDim =
remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r3(
const BlockSrcDesc& block_src_desc,
const Index& src_block_slice_origin,
const BlockDstDesc& block_dst_desc,
const Index& dst_block_slice_origin)
: threadwise_read_(block_src_desc,
make_zero_multi_index<nDim>(),
thread_buffer_desc_,
make_zero_multi_index<nDim>()),
threadwise_write_(thread_buffer_desc_,
make_zero_multi_index<nDim>(),
block_dst_desc,
make_zero_multi_index<nDim>())
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(
nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
......@@ -533,13 +72,10 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r3
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_read_.SetSrcSliceOrigin(block_src_desc,
src_block_slice_origin + thread_data_id_begin);
threadwise_read_.SetDstSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetSrcSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetDstSliceOrigin(block_dst_desc,
dst_block_slice_origin + thread_data_id_begin);
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_id_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_id_begin);
}
}
......@@ -551,86 +87,66 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r3
return thread_cluster_id * ThreadSliceLengths{};
}
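A worked example of the mapping above, using hypothetical numbers and a plain row-major cluster (the real mapping goes through make_cluster_descriptor with ThreadClusterArrangeOrder, so this is only an illustration): with ThreadClusterLengths = {4, 64} (256 threads) and ThreadSliceLengths = {2, 2}, the block slice is {8, 128} and thread 70 begins at data index {2, 12}.

// Hypothetical numbers, row-major cluster assumed for illustration.
constexpr int cluster_len[2] = {4, 64}; // 256 threads
constexpr int slice_len[2]   = {2, 2};  // each thread owns a 2x2 slice
constexpr int tid            = 70;

constexpr int cluster_id[2] = {tid / cluster_len[1], tid % cluster_len[1]}; // {1, 6}
constexpr int data_begin[2] = {cluster_id[0] * slice_len[0],
                               cluster_id[1] * slice_len[1]}; // {2, 12}

static_assert(data_begin[0] == 2 && data_begin[1] == 12,
              "thread 70 covers rows 2..3 and columns 12..13 of the 8x128 block slice");

int main() {}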
__device__ void RunRead(const BlockSrcDesc& block_src_desc,
const BlockSrcData* p_block_src,
BlockSrcData* p_thread_buffer)
__device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer);
threadwise_transfer_.RunRead(src_desc, p_src);
}
}
__device__ void RunWrite(const BlockDstDesc& block_dst_desc,
BlockDstData* p_block_dst,
BlockDstData* p_thread_buffer)
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.Run(
thread_buffer_desc_, p_thread_buffer, block_dst_desc, p_block_dst);
threadwise_transfer_.RunWrite(dst_desc, p_dst);
}
}
__device__ void MoveSrcSliceWindow(const BlockSrcDesc& block_src_desc, const Index& step)
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.MoveSrcSliceWindow(block_src_desc, step);
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const BlockDstDesc& block_dst_desc, const Index& step)
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.MoveDstSliceWindow(block_dst_desc, step);
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto thread_buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
using ThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
SrcAddressSpace,
DstAddressSpace,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc,
decltype(thread_buffer_desc_),
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1,
ThreadTransferMoveBackSrcCoord,
true>;
using ThreadwiseWrite =
ThreadwiseDynamicTensorSliceTransfer_v1r2<decltype(thread_buffer_desc_),
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorWriteDim,
1,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride,
true,
ThreadTransferMoveBackDstCoord>;
ThreadwiseRead threadwise_read_;
ThreadwiseWrite threadwise_write_;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
......
......@@ -32,6 +32,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
......@@ -39,6 +40,7 @@ template <index_t BlockSize,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
......@@ -130,28 +132,28 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// A matrix blockwise copy
auto a_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1,
true,
true>(
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
Float,
Float,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
......@@ -159,32 +161,28 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// B matrix blockwise copy
auto b_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1,
#if 0
true.
#else
false,
#endif
true>(
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
Float,
Float,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
......@@ -253,25 +251,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#if 0
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#else
// HACK: fuse threadwise copy move-back coordinate with move src slice window
constexpr auto b_block_slice_copy_step =
b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
#endif
// LDS double buffer: preload data into LDS
{
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
}
// LDS double buffer: main body
......@@ -298,19 +286,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
}
}
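The preload block and the main-body loop above form the usual LDS ping-pong: while the block GEMM consumes the buffer filled in the previous step, the blockwise copies fill the other half. A host-side toy of that schedule (illustration only, not the kernel's actual loop structure):

#include <cstdio>

int main()
{
    const int num_k_blocks = 4;
    int buf[2] = {};     // stand-ins for the two LDS halves
    buf[0] = 0;          // preload: first K block goes into buffer 0

    for(int k = 0; k + 1 < num_k_blocks; ++k)
    {
        const int now  = k % 2;   // buffer holding the current K block
        const int next = 1 - now; // buffer being filled for the next one

        buf[next] = k + 1; // "RunRead + RunWrite" of the next K block

        std::printf("compute on buf[%d] (K block %d), loading K block %d into buf[%d]\n",
                    now, buf[now], k + 1, next);
    }

    // tail: the last K block has already been written to LDS
    std::printf("compute on buf[%d] (K block %d)\n",
                (num_k_blocks - 1) % 2, buf[(num_k_blocks - 1) % 2]);
}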
......@@ -323,21 +308,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_block_copy.RunWrite(
a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
b_block_copy.RunWrite(
b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
......@@ -378,6 +358,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseDynamicTensorSliceTransfer_v1r2<
AccFloat,
Float,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
......@@ -389,13 +371,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
1>(c_m0_m1_n0_n1_thread_desc,
make_multi_index(0, 0, 0, 0),
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
1,
true,
true>(c_m0_m1_n0_n1_thread_desc,
make_multi_index(0, 0, 0, 0),
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
}
}
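The (m0, m1, n0, n1) destination index used above is a divide/modulo split of the thread's global row and column by M1 = MPerThread * MLevel0Cluster * MLevel1Cluster and the analogous N1. A small check with hypothetical numbers:

// Hypothetical: MPerThread = 4, MLevel0Cluster = 2, MLevel1Cluster = 8.
constexpr int M1 = 4 * 2 * 8; // 64
constexpr int m_thread_data_on_global = 200;

static_assert(m_thread_data_on_global / M1 == 3, "m0");
static_assert(m_thread_data_on_global % M1 == 8, "m1, since 3 * 64 + 8 == 200");

int main() {}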
......@@ -423,368 +407,5 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, IsEvenNumberKBlockLoop>{});
}
};
template <index_t BlockSize,
typename Float,
typename AccFloat,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseDynamicGemm_km_kn_mn_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
BBlockTransferDstScalarPerVector_N,
MPerThread,
NPerThread);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(Float);
}
template <typename... ADesc, typename... BDesc, typename... CDesc>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t K = a_k_m_global_desc.GetLength(I0);
const index_t M = a_k_m_global_desc.GetLength(I1);
const index_t N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
#if 0
const index_t m_block_work_num = M / MPerBlock;
const index_t n_block_work_num = N / NPerBlock;
#else
// Hack: this forces the result into SGPR
const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
#endif
#if 0
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#else
// Hack: this forces the result into SGPR
const index_t m_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#endif
const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
BBlockTransferDstScalarPerVector_N,
MPerThread,
NPerThread);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// A matrix blockwise copy
auto a_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1,
true,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1,
#if 0
true.
#else
false,
#endif
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t a_k_m_block_mtx_stride =
a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));
constexpr index_t b_k_n_block_mtx_stride =
b_k_n_block_desc.CalculateOffset(make_multi_index(1, 0)) -
b_k_n_block_desc.CalculateOffset(make_multi_index(0, 0));
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
const auto block_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block = p_shared_block;
Float* p_b_block = p_shared_block + a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#if 0
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#else
// HACK: fuse threadwise copy move-back coordinate with move src slice window
constexpr auto b_block_slice_copy_step =
b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
#endif
// preload data into LDS
{
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
}
// main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
// load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
__syncthreads();
// GEMM on current data
block_gemm.Run(p_a_block, p_b_block, p_c_thread);
__syncthreads();
// store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
}
// tail
{
__syncthreads();
block_gemm.Run(p_a_block, p_b_block, p_c_thread);
}
// output: register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed<4>(
make_multi_index(MRepeat, MPerThread, NRepeat, NPerThread));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
block_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseDynamicTensorSliceTransfer_v1r2<
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
1,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
1>(c_m0_m1_n0_n1_thread_desc,
make_multi_index(0, 0, 0, 0),
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
}
}
template <typename... ADesc, typename... BDesc, typename... CDesc>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
p_shared_block);
}
};
} // namespace ck
#endif
......@@ -7,185 +7,12 @@
namespace ck {
// this version tends to have a scratch memory issue, due to:
// 1. It keeps reference to tensor descriptor
// 2. It constructs new tensor coordinate in this->Run()
template <typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SrcDstDimAccessOrder,
index_t SrcDstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector>
struct ThreadwiseDynamicTensorSliceTransfer_v1r1
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_dynamic_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_dynamic_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_desc_(src_desc),
src_slice_origin_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_desc_(dst_desc),
dst_slice_origin_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
{
}
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r1()
: ThreadwiseDynamicTensorSliceTransfer_v1r1(
SrcDesc{}, make_zero_multi_index<nDim>(), DstDesc{}, make_zero_multi_index<nDim>())
{
}
template <typename SrcData, typename DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
// comment: constructing the tensor coordinate here tends to cause a scratch memory issue
auto src_coord = src_slice_origin_;
auto dst_coord = dst_slice_origin_;
// TODO: use constexpr for coordinate-step to make sure the compiler behaves correctly
const auto src_step_0_p1 =
make_dynamic_tensor_coordinate_step(src_desc_, make_multi_index(0, 1));
const auto src_step_0_m1 =
make_dynamic_tensor_coordinate_step(src_desc_, make_multi_index(0, -1));
const auto src_step_p1_0 =
make_dynamic_tensor_coordinate_step(src_desc_, make_multi_index(1, 0));
const auto src_step_m1_0 =
make_dynamic_tensor_coordinate_step(src_desc_, make_multi_index(-1, 0));
const auto dst_step_0_p1 =
make_dynamic_tensor_coordinate_step(dst_desc_, make_multi_index(0, 1));
const auto dst_step_0_m1 =
make_dynamic_tensor_coordinate_step(dst_desc_, make_multi_index(0, -1));
const auto dst_step_p1_0 =
make_dynamic_tensor_coordinate_step(dst_desc_, make_multi_index(1, 0));
const auto dst_step_m1_0 =
make_dynamic_tensor_coordinate_step(dst_desc_, make_multi_index(-1, 0));
constexpr index_t Len0 = SliceLengths{}[0];
constexpr index_t Len1 = SliceLengths{}[1];
bool forward_dim0 = true;
bool forward_dim1 = true;
// hardcoded for 2d loop for now
#pragma unroll
for(index_t i0 = 0; i0 < Len0; ++i0)
{
#pragma unroll
for(index_t i1 = 0; i1 < Len1; ++i1)
{
// do work
transfer_data<SrcData,
1,
SrcAddressSpace,
DstAddressSpace,
DstInMemOp,
SrcScalarStrideInVector,
DstScalarStrideInVector>(
p_src,
src_coord.GetOffset(),
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc_,
src_coord),
src_desc_.GetElementSpaceSize(),
p_dst,
dst_coord.GetOffset(),
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc_,
dst_coord),
dst_desc_.GetElementSpaceSize());
// move dim1 iterator
if(i1 < Len1 - 1)
{
if(forward_dim1)
{
move_dynamic_tensor_coordinate(src_desc_, src_coord, src_step_0_p1);
move_dynamic_tensor_coordinate(dst_desc_, dst_coord, dst_step_0_p1);
}
else
{
move_dynamic_tensor_coordinate(src_desc_, src_coord, src_step_0_m1);
move_dynamic_tensor_coordinate(dst_desc_, dst_coord, dst_step_0_m1);
}
}
}
// switch dim1 iteration direction
forward_dim1 = !forward_dim1;
// move dim0 iterator
if(i0 < Len0 - 1)
{
if(forward_dim0)
{
move_dynamic_tensor_coordinate(src_desc_, src_coord, src_step_p1_0);
move_dynamic_tensor_coordinate(dst_desc_, dst_coord, dst_step_p1_0);
}
else
{
move_dynamic_tensor_coordinate(src_desc_, src_coord, src_step_m1_0);
move_dynamic_tensor_coordinate(dst_desc_, dst_coord, dst_step_m1_0);
}
}
}
}
__device__ void SetSrcSliceOrigin(const Index& src_slice_origin_idx)
{
src_slice_origin_ = make_dynamic_tensor_coordinate(src_desc_, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const Index& dst_slice_origin_idx)
{
dst_slice_origin_ = make_dynamic_tensor_coordinate(dst_desc_, dst_slice_origin_idx);
}
// src_slice_origin_step_idx needs to be known at compile time, for performance reasons
__device__ void MoveSrcSliceWindow(const Index& src_slice_origin_step_idx)
{
// is it OK to construct a new step every time?
const auto src_slice_origin_step =
make_dynamic_tensor_coordinate_step(src_desc_, src_slice_origin_step_idx);
move_dynamic_tensor_coordinate(src_desc_, src_slice_origin_, src_slice_origin_step);
}
// dst_slice_origin_step_idx needs to be known at compile time, for performance reasons
__device__ void MoveDstSliceWindow(const Index& dst_slice_origin_step_idx)
{
// is it OK to construct a new step every time?
const auto dst_slice_origin_step =
make_dynamic_tensor_coordinate_step(dst_desc_, dst_slice_origin_step_idx);
move_dynamic_tensor_coordinate(dst_desc_, dst_slice_origin_, dst_slice_origin_step);
}
private:
const SrcDesc& src_desc_;
const DstDesc& dst_desc_;
SrcCoord src_slice_origin_;
DstCoord dst_slice_origin_;
};
// this version is less likely to have a scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
template <typename SrcDesc,
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename SrcDstDimAccessOrder,
......@@ -197,8 +24,12 @@ template <typename SrcDesc,
InMemoryDataOperation DstInMemOp,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool MoveBackSrcCoord = true,
bool MoveBackDstCoord = true>
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v1r2
{
static constexpr index_t nDim = SliceLengths::Size();
......@@ -225,7 +56,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
{
}
template <typename SrcData, typename DstData>
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_slice_origin_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_slice_origin_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
__device__ void
Run(const SrcDesc& src_desc, const SrcData* p_src, const DstDesc& dst_desc, DstData* p_dst)
{
......@@ -256,13 +96,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
constexpr index_t Len1 = SliceLengths{}[1];
#pragma unroll
for(index_t i0 = 0; i0 < Len0; ++i0)
for(index_t iter0 = 0; iter0 < Len0; ++iter0)
{
#pragma unroll
for(index_t i1 = 0; i1 < Len1; ++i1)
for(index_t iter1 = 0; iter1 < Len1; ++iter1)
{
#if 1 // debug
// do work
// do work
transfer_data<SrcData,
1,
SrcAddressSpace,
......@@ -280,68 +119,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_),
dst_desc.GetElementSpaceSize());
#else
if constexpr(SrcAddressSpace == AddressSpace::Global &&
DstAddressSpace == AddressSpace::Vgpr)
{
if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_))
{
const SrcData tmp = amd_buffer_load<SrcData, 1>(
p_src,
src_slice_origin_.GetOffset(),
coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_),
src_desc.GetElementSpaceSize());
const index_t dst_offset = dst_slice_origin_.GetOffset();
p_dst[dst_offset] = tmp;
}
}
else if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global)
{
const SrcData zeros = 0;
const bool src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_);
const bool dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_);
amd_buffer_store<SrcData, 1>(
src_valid ? &(p_src[src_slice_origin_.GetOffset()]) : &zeros,
p_dst,
dst_slice_origin_.GetOffset(),
dst_valid,
dst_desc.GetElementSpaceSize());
}
else
{
if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_))
{
if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_))
{
p_dst[dst_slice_origin_.GetOffset()] =
p_src[src_slice_origin_.GetOffset()];
}
else
{
p_dst[dst_slice_origin_.GetOffset()] = 0;
}
}
}
#endif
// move dim1 iterator
if(i1 < Len1 - 1)
if(iter1 < Len1 - 1)
{
bool forward_dim1 = (i0 % 2 == 0);
bool forward_dim1 = (iter0 % 2 == 0);
if(forward_dim1)
{
......@@ -361,7 +143,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
}
// move dim0 iterator
if(i0 < Len0 - 1)
if(iter0 < Len0 - 1)
{
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
......@@ -416,22 +198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
constexpr index_t Len2 = SliceLengths{}[2];
constexpr index_t Len3 = SliceLengths{}[3];
bool forward_dim0 = true;
bool forward_dim1 = true;
bool forward_dim2 = true;
bool forward_dim3 = true;
#pragma unroll
for(index_t i0 = 0; i0 < Len0; ++i0)
for(index_t iter0 = 0; iter0 < Len0; ++iter0)
{
#pragma unroll
for(index_t i1 = 0; i1 < Len1; ++i1)
for(index_t iter1 = 0; iter1 < Len1; ++iter1)
{
#pragma unroll
for(index_t i2 = 0; i2 < Len2; ++i2)
for(index_t iter2 = 0; iter2 < Len2; ++iter2)
{
#pragma unroll
for(index_t i3 = 0; i3 < Len3; ++i3)
for(index_t iter3 = 0; iter3 < Len3; ++iter3)
{
// do work
transfer_data<SrcData,
......@@ -453,8 +230,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
dst_desc.GetElementSpaceSize());
// move dim3 iterator
if(i3 < Len3 - 1)
if(iter3 < Len3 - 1)
{
bool forward_dim3 = (iter2 % 2 == 0);
if(forward_dim3)
{
move_dynamic_tensor_coordinate(
......@@ -472,12 +251,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
}
}
// switch dim3 iteration direction
forward_dim3 = !forward_dim3;
// move dim2 iterator
if(i2 < Len2 - 1)
if(iter2 < Len2 - 1)
{
bool forward_dim2 = (iter1 % 2 == 0);
if(forward_dim2)
{
move_dynamic_tensor_coordinate(
......@@ -495,12 +273,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
}
}
// switch dim2 iteration direction
forward_dim2 = !forward_dim2;
// move dim1 iterator
if(i1 < Len1 - 1)
if(iter1 < Len1 - 1)
{
bool forward_dim1 = (iter0 % 2 == 0);
if(forward_dim1)
{
move_dynamic_tensor_coordinate(
......@@ -518,59 +295,132 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
}
}
// switch dim1 iteration direction
forward_dim1 = !forward_dim1;
// move dim0 iterator
if(i0 < Len0 - 1)
// move dim0 iterator:
if(iter0 < Len0 - 1)
{
if(forward_dim0)
{
move_dynamic_tensor_coordinate(
src_desc, src_slice_origin_, src_step_p1_0_0_0);
move_dynamic_tensor_coordinate(
dst_desc, dst_slice_origin_, dst_step_p1_0_0_0);
}
else
{
move_dynamic_tensor_coordinate(
src_desc, src_slice_origin_, src_step_m1_0_0_0);
move_dynamic_tensor_coordinate(
dst_desc, dst_slice_origin_, dst_step_m1_0_0_0);
}
// move forward in dim0
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0_0_0);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0_0_0);
}
}
}
// move src and dst coordinate back to their origins
if constexpr(MoveBackSrcCoord)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_step_back =
make_dynamic_tensor_coordinate_step(src_desc, GetCoordinateStepBack());
const auto src_back_step =
make_dynamic_tensor_coordinate_step(src_desc, GetCoordinateBackStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_back);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_back_step);
}
if constexpr(MoveBackDstCoord)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_step_back =
make_dynamic_tensor_coordinate_step(dst_desc, GetCoordinateStepBack());
const auto dst_back_step =
make_dynamic_tensor_coordinate_step(dst_desc, GetCoordinateBackStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_back);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_back_step);
}
}
__device__ static constexpr auto GetCoordinateStepBack()
__device__ static constexpr auto GetCoordinateBackStep()
{
MultiIndex<nDim> step_back;
MultiIndex<nDim> back_step;
step_back(Number<0>{}) = 1 - SliceLengths{}[0];
back_step(Number<0>{}) = 1 - SliceLengths{}[0];
static_for<1, nDim, 1>{}([&](auto i) {
step_back(i) = (SliceLengths{}[i - Number<1>{}] % 2 == 0) ? 0 : (1 - SliceLengths{}[i]);
back_step(i) = (SliceLengths{}[i - Number<1>{}] % 2 == 0) ? 0 : (1 - SliceLengths{}[i]);
});
return step_back;
return back_step;
}
// src_slice_origin_step_idx needs to be known at compile time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// is it OK to construct a new step every time?
const auto src_slice_origin_step =
make_dynamic_tensor_coordinate_step(src_desc, src_slice_origin_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_slice_origin_step);
}
// dst_slice_origin_step_idx needs to be known at compile time, for performance reasons
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// is it OK to construct a new step every time?
const auto dst_slice_origin_step =
make_dynamic_tensor_coordinate_step(dst_desc, dst_slice_origin_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_slice_origin_step);
}
private:
SrcCoord src_slice_origin_;
DstCoord dst_slice_origin_;
};
// this version does the following things to avoid "alloca" in LLVM-IR, which would cause scratch
// memory and sometimes useless instructions
// 1. It does not keep a reference to the tensor descriptor
// 2. It does not construct a new tensor coordinate for this->Run()
// 3. It does not use a pointer for the VGPR thread buffer
// 4. It calculates the offset into the thread buffer directly, instead of moving a coordinate
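// Illustrative sketch (assumption: StaticallyIndexedArray and static_for behave as used by the
// class below): keeping every buffer access behind a compile-time Number<> index is what lets
// the buffer live in registers instead of an "alloca"-backed local array.
template <typename T>
__device__ void statically_indexed_buffer_sketch(T x)
{
    StaticallyIndexedArray<T, 4> buf; // no pointer and no runtime indexing into the buffer
    static_for<0, 4, 1>{}([&](auto i) {
        buf(i) = x; // each offset is a compile-time Number<>, so no scratch memory is needed
    });
}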
template <typename SliceLengths,
InMemoryDataOperation DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
bool SrcResetCoordinateAfterRun, // control whether the src coordinate is moved back after each
                                 // RunRead(); if not, the reset is fused into
                                 // MoveSrcSliceWindow() to save addr computation
bool DstResetCoordinateAfterRun> // control whether the dst coordinate is moved back after each
                                 // RunWrite(); if not, the reset is fused into
                                 // MoveDstSliceWindow() to save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_dynamic_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_dynamic_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_slice_origin_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_slice_origin_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
{
static_assert(SrcAddressSpace == AddressSpace::Global or
SrcAddressSpace == AddressSpace::Lds,
"wrong!");
static_assert(DstAddressSpace == AddressSpace::Global or
DstAddressSpace == AddressSpace::Lds,
"wrong!");
}
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3()
: ThreadwiseDynamicTensorSliceTransfer_v3(
SrcDesc{}, make_zero_multi_index<nDim>(), DstDesc{}, make_zero_multi_index<nDim>())
{
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -583,15 +433,188 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
dst_slice_origin_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
__device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
{
static_assert(remove_reference_t<SrcDesc>::GetNumOfDimension() == 2,
"wrong! hardcoded for 2D tensor");
// hardcoded for 2D
// TODO implement N-D
if constexpr(remove_reference_t<SrcDesc>::GetNumOfDimension() == 2)
{
// TODO use constexpr for coordinate-step to make sure compiler behave correctly
const auto src_step_0_p1 =
make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 1));
const auto src_step_0_m1 =
make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, -1));
const auto src_step_p1_0 =
make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(1, 0));
const auto src_step_m1_0 =
make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-1, 0));
constexpr index_t Len0 = SliceLengths{}[0];
constexpr index_t Len1 = SliceLengths{}[1];
static_for<0, Len0, 1>{}([&](auto iter0) {
static_for<0, Len1, 1>{}([&](auto iter1) {
// step direction
constexpr bool forward_dim1 = (iter0.value % 2 == 0);
constexpr index_t i0 = iter0.value;
constexpr index_t i1 = forward_dim1 ? iter1.value : Len1 - iter1.value - 1;
// do work
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(make_multi_index(i0, i1));
// hardcoding for buffer_load
// TODO refactor transfer_data() to encapsulate this
static_assert(SrcAddressSpace == AddressSpace::Global,
"wrong! hardcoded to use buffer_load, src must be global mem");
buffer_(Number<buffer_offset>{}) = amd_buffer_load<SrcData, 1>(
p_src,
src_slice_origin_.GetOffset(),
coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_),
src_desc.GetElementSpaceSize());
// move dim1 iterator
if constexpr(iter1.value < Len1 - 1)
{
if constexpr(forward_dim1)
{
move_dynamic_tensor_coordinate(
src_desc, src_slice_origin_, src_step_0_p1);
}
else
{
move_dynamic_tensor_coordinate(
src_desc, src_slice_origin_, src_step_0_m1);
}
}
});
// move dim0 iterator
if constexpr(iter0.value < Len0 - 1)
{
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
}
});
}
// move src and dst coordinate back to their origins
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_back_step =
make_dynamic_tensor_coordinate_step(src_desc, GetCoordinateBackStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_back_step);
}
}
__device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
{
static_assert(remove_reference_t<DstDesc>::GetNumOfDimension() == 2,
"wrong! hardcoded for 2D tensor");
// hardcoded for 2D
// TODO implement N-D
if constexpr(remove_reference_t<DstDesc>::GetNumOfDimension() == 2)
{
// TODO use constexpr for coordinate-step to make sure compiler behave correctly
const auto dst_step_0_p1 =
make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 1));
const auto dst_step_0_m1 =
make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, -1));
const auto dst_step_p1_0 =
make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(1, 0));
const auto dst_step_m1_0 =
make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-1, 0));
constexpr index_t Len0 = SliceLengths{}[0];
constexpr index_t Len1 = SliceLengths{}[1];
static_for<0, Len0, 1>{}([&](auto iter0) {
static_for<0, Len1, 1>{}([&](auto iter1) {
// step direction
constexpr bool forward_dim1 = (iter0.value % 2 == 0);
constexpr index_t i0 = iter0.value;
constexpr index_t i1 = forward_dim1 ? iter1.value : Len1 - iter1.value - 1;
// do work
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(make_multi_index(i0, i1));
// hardcoding for ds_write
// TODO refactor transfer_data() to encapsulate this
static_assert(DstAddressSpace == AddressSpace::Lds &&
DstInMemOp == InMemoryDataOperation::Set,
"wrong! hardcoded for ds_write");
p_dst[dst_slice_origin_.GetOffset()] = buffer_[Number<buffer_offset>{}];
// move dim1 iterator
if constexpr(iter1.value < Len1 - 1)
{
if constexpr(forward_dim1)
{
move_dynamic_tensor_coordinate(
dst_desc, dst_slice_origin_, dst_step_0_p1);
}
else
{
move_dynamic_tensor_coordinate(
dst_desc, dst_slice_origin_, dst_step_0_m1);
}
}
});
// move dim0 iterator
if constexpr(iter0.value < Len0 - 1)
{
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
}
});
}
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_back_step =
make_dynamic_tensor_coordinate_step(dst_desc, GetCoordinateBackStep());
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_back_step);
}
}
__device__ static constexpr auto GetCoordinateBackStep()
{
MultiIndex<nDim> back_step;
back_step(Number<0>{}) = 1 - SliceLengths{}[0];
static_for<1, nDim, 1>{}([&](auto i) {
back_step(i) = (SliceLengths{}[i - Number<1>{}] % 2 == 0) ? 0 : (1 - SliceLengths{}[i]);
});
return back_step;
}
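// Worked example (illustrative): for SliceLengths = {2, 3} the zig-zag traversal ends at
// (1, 0), so the back step is (1 - 2, 0) = (-1, 0); for SliceLengths = {3, 3} it ends at
// (2, 2), so the back step is (1 - 3, 1 - 3) = (-2, -2), matching the parity rule above.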
// src_slice_origin_step_idx needs to be known at compile time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// is it OK to construct a new step every time?
const auto src_slice_origin_step =
make_dynamic_tensor_coordinate_step(src_desc, src_slice_origin_step_idx);
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateBackStep();
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_slice_origin_step);
const auto adjusted_step = make_dynamic_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, adjusted_step);
}
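// Example (illustrative): with SliceLengths = {2, 3}, GetCoordinateBackStep() is (-1, 0);
// moving the slice window by (2, 0) with SrcResetCoordinateAfterRun == false therefore uses
// the single fused step (2, 0) + (-1, 0) = (1, 0) instead of a separate reset plus move.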
// dst_slice_origin_step_idx needs to be known at compile time, for performance reasons
......@@ -599,13 +622,23 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
const Index& dst_slice_origin_step_idx)
{
// is it OK to construct a new step every time?
const auto dst_slice_origin_step =
make_dynamic_tensor_coordinate_step(dst_desc, dst_slice_origin_step_idx);
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateBackStep();
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_slice_origin_step);
const auto adjusted_step = make_dynamic_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, adjusted_step);
}
private:
static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(SliceLengths{}));
static constexpr index_t buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
SrcCoord src_slice_origin_;
DstCoord dst_slice_origin_;
};
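// Usage sketch (illustrative; the function, descriptor and pointer names here are assumptions):
// one way a caller might drive the transfer once per K-iteration of a GEMM main loop, with a
// compile-time KPerBlock so the slice-window step is known at compile time.
template <index_t KPerBlock,
          typename Transfer,
          typename SrcDesc,
          typename DstDesc,
          typename SrcData,
          typename DstData>
__device__ void threadwise_transfer_v3_usage_sketch(Transfer& transfer,
                                                    const SrcDesc& src_desc,
                                                    const SrcData* p_src_global,
                                                    const DstDesc& dst_desc,
                                                    DstData* p_dst_lds)
{
    // global memory -> static VGPR buffer
    transfer.RunRead(src_desc, p_src_global);
    // static VGPR buffer -> LDS
    transfer.RunWrite(dst_desc, p_dst_lds);
    // slide the source window down GemmK by one block; with SrcResetCoordinateAfterRun == false
    // this single move also folds in the coordinate reset
    transfer.MoveSrcSliceWindow(src_desc, make_multi_index(KPerBlock, 0));
}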
......
......@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
}
template <class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
constexpr std::size_t
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
{
using namespace ck;
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t N = out_desc.GetLength(I0);
constexpr index_t K = out_desc.GetLength(I1);
constexpr index_t Ho = out_desc.GetLength(I2);
constexpr index_t Wo = out_desc.GetLength(I3);
const index_t N = out_desc.GetLength(I0);
const index_t K = out_desc.GetLength(I1);
const index_t Ho = out_desc.GetLength(I2);
const index_t Wo = out_desc.GetLength(I3);
constexpr index_t C = wei_desc.GetLength(I1);
constexpr index_t Y = wei_desc.GetLength(I2);
constexpr index_t X = wei_desc.GetLength(I3);
const index_t C = wei_desc.GetLength(I1);
const index_t Y = wei_desc.GetLength(I2);
const index_t X = wei_desc.GetLength(I3);
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}
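// Example (illustrative): for N = 128, K = 256, C = 256, Ho = Wo = 14, Y = X = 3 this gives
// 2 * 128 * 256 * 14 * 14 * 256 * 3 * 3 ~ 2.96e10 FLOP per forward convolution.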
......
......@@ -577,7 +577,7 @@ int main(int argc, char* argv[])
LeftPads{},
RightPads{},
nrepeat);
#elif 0
#elif 1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
......