Commit 3b3cfae5 authored by Chao Liu

add blockwise copy that doesn't have a thread buffer as a member, to avoid alloca and therefore scratch memory
parent 54138dc8
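
The change replaces the v2r2 blockwise copy with a v2r3 variant whose per-thread staging buffer lives in the caller's stack frame instead of inside the transfer object. The sketch below is a simplified, hypothetical illustration of that pattern (invented names CopyWithMemberBuffer, CopyWithoutMemberBuffer, copy_example; 1-D copy; fixed buffer size), not the actual CK classes; the real implementation follows in the diff.

// v2r2-style: the staging buffer is a data member, so the whole object must be
// addressable; if the compiler cannot promote the member array to registers, the
// object is backed by an alloca and every access goes through scratch memory.
struct CopyWithMemberBuffer
{
    float thread_buffer_[8];

    __device__ void Run(const float* src, float* dst)
    {
        for(int i = 0; i < 8; ++i) { thread_buffer_[i] = src[i]; }
        for(int i = 0; i < 8; ++i) { dst[i] = thread_buffer_[i]; }
    }
};

// v2r3-style: the caller owns the staging buffer and passes it to each phase, so the
// transfer object stays small and the plain local array is eligible for registers.
struct CopyWithoutMemberBuffer
{
    __device__ void RunRead(const float* src, float* thread_buffer)
    {
        for(int i = 0; i < 8; ++i) { thread_buffer[i] = src[i]; }
    }

    __device__ void RunWrite(float* dst, const float* thread_buffer)
    {
        for(int i = 0; i < 8; ++i) { dst[i] = thread_buffer[i]; }
    }
};

__device__ void copy_example(const float* src, float* dst)
{
    CopyWithoutMemberBuffer copy;

    float thread_buffer[8]; // local array owned by the caller, as in the v2r3 usage below

    copy.RunRead(src, thread_buffer);
    copy.RunWrite(dst, thread_buffer);
}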
@@ -174,37 +174,37 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
@@ -121,7 +121,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v1r1
ThreadwiseTransfer threadwise_transfer_;
};
// this version has scratch memory issue, due to:
// this version is very likely to have a scratch memory issue, due to:
// 1. ThreadwiseDynamicTensorSliceTransfer_v1r1 keeps reference to tensor descriptor
// 2. threadwise_dynamic_tensor_slice_transfer_v1r1 constructs new tensor coordinate
template <index_t BlockSize,
@@ -287,7 +287,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r1
BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
};
// this version does not have scratch memory issue, due to:
// this version does the following to avoid the scratch memory issue:
// 1. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
// 2. threadwise_dynamic_tensor_slice_transfer_v1r2 does not construct new tensor coordinate
template <index_t BlockSize,
@@ -462,5 +462,169 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r2
BlockSrcData p_thread_buffer_[thread_buffer_element_size_];
};
// this version does the following to avoid the scratch memory issue:
// 1. BlockwiseDynamicTensorSliceTransfer_v2r3 doesn't allocate thread buffer (array) as member
// 2. ThreadwiseDynamicTensorSliceTransfer_v1r2 does not keep reference to tensor descriptor
// 3. threadwise_dynamic_tensor_slice_transfer_v1r2 does not construct new tensor coordinate
template <index_t BlockSize,
typename BlockSrcData,
typename BlockDstData,
typename BlockSrcDesc,
typename BlockDstDesc,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorReadDim,
index_t DstVectorWriteDim,
index_t SrcDataPerRead,
index_t DstDataPerWrite,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t SrcDataStride,
index_t DstDataStride>
struct BlockwiseDynamicTensorSliceTransfer_v2r3
{
static constexpr index_t nDim =
remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v2r3(
const BlockSrcDesc& block_src_desc,
const Index& src_block_slice_origin,
const BlockDstDesc& block_dst_desc,
const Index& dst_block_slice_origin)
: threadwise_read_(block_src_desc,
make_zero_multi_index<nDim>(),
thread_buffer_desc_,
make_zero_multi_index<nDim>()),
threadwise_write_(thread_buffer_desc_,
make_zero_multi_index<nDim>(),
block_dst_desc,
make_zero_multi_index<nDim>())
{
static_assert(
nDim == remove_reference_t<remove_cv_t<BlockSrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<BlockDstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
threadwise_read_.SetSrcSliceOrigin(block_src_desc,
src_block_slice_origin + thread_data_id_begin);
threadwise_read_.SetDstSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetSrcSliceOrigin(thread_buffer_desc_, make_zero_multi_index<nDim>());
threadwise_write_.SetDstSliceOrigin(block_dst_desc,
dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ static constexpr auto CalculateThreadDataBegin()
{
const auto thread_cluster_id =
thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
return thread_cluster_id * ThreadSliceLengths{};
}
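// Worked example (hypothetical parameters, not from this commit): with
// ThreadClusterLengths = Sequence<4, 64>, ThreadSliceLengths = Sequence<2, 4>,
// ThreadClusterArrangeOrder = Sequence<0, 1> and BlockSize = 256, thread 70 gets
// cluster index (70 / 64, 70 % 64) = (1, 6), so its slice of the block window
// begins at (1 * 2, 6 * 4) = (2, 24) relative to the block slice origin.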
__device__ void RunRead(const BlockSrcDesc& block_src_desc,
const BlockSrcData* p_block_src,
BlockSrcData* p_thread_buffer)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.Run(block_src_desc, p_block_src, thread_buffer_desc_, p_thread_buffer);
}
}
__device__ void RunWrite(const BlockDstDesc& block_dst_desc,
BlockDstData* p_block_dst,
BlockDstData* p_thread_buffer)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.Run(
thread_buffer_desc_, p_thread_buffer, block_dst_desc, p_block_dst);
}
}
__device__ void MoveSrcSliceWindow(const BlockSrcDesc& block_src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_read_.MoveSrcSliceWindow(block_src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const BlockDstDesc& block_dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_write_.MoveDstSliceWindow(block_dst_desc, step);
}
}
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static constexpr auto thread_buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc,
decltype(thread_buffer_desc_),
ThreadSliceLengths,
SrcDimAccessOrder,
SrcVectorReadDim,
SrcDataPerRead,
1,
SrcAddressSpace,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
SrcDataStride,
1>;
using ThreadwiseWrite = ThreadwiseDynamicTensorSliceTransfer_v1r2<decltype(thread_buffer_desc_),
BlockDstDesc,
ThreadSliceLengths,
DstDimAccessOrder,
DstVectorWriteDim,
1,
DstDataPerWrite,
AddressSpace::Vgpr,
DstAddressSpace,
DstInMemOp,
1,
DstDataStride>;
ThreadwiseRead threadwise_read_;
ThreadwiseWrite threadwise_write_;
};
} // namespace ck
#endif
@@ -42,7 +42,7 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseDynamicGemm_km_kn_mn_v1
struct GridwiseDynamicGemm_km_kn_mn_v1r1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -114,6 +114,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// A matrix blockwise copy
auto a_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
@@ -379,5 +380,388 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
}
};
template <index_t BlockSize,
typename Float,
typename AccFloat,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseDynamicGemm_km_kn_mn_v1r2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
BBlockTransferDstScalarPerVector_N,
MPerThread,
NPerThread);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
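// Worked example (hypothetical tile sizes, not part of this commit): with Float = float,
// KPerBlock = 8, MPerBlock = NPerBlock = 128, MPerThread = NPerThread = 4 and both Dst
// scalar-per-vector values equal to 4, max_lds_align = lcm(4, 4, 4, 4) = 4. Since 128 is
// already a multiple of 4, each aligned [8, 128] block descriptor spans 8 * 128 = 1024
// elements, and double buffering of A and B gives
// 2 * (1024 + 1024) * sizeof(float) = 16384 bytes (16 KB) of LDS.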
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t K = a_k_m_global_desc.GetLength(I0);
const index_t M = a_k_m_global_desc.GetLength(I1);
const index_t N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
const index_t m_block_work_num = M / MPerBlock;
const index_t n_block_work_num = N / NPerBlock;
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
BBlockTransferDstScalarPerVector_N,
MPerThread,
NPerThread);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// A matrix blockwise copy
auto a_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1>(a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1>(b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t a_k_m_block_mtx_stride =
a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));
constexpr index_t b_k_n_block_mtx_stride =
b_k_n_block_desc.CalculateOffset(make_multi_index(1, 0)) -
b_k_n_block_desc.CalculateOffset(make_multi_index(0, 0));
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
const auto block_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// LDS double buffer: preload data into LDS
{
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
}
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
}
}
// LDS double buffer: tail
{
if constexpr(IsEvenNumberKBlockLoop) // if there are 2 iterations left
{
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_block_copy.RunWrite(
a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
b_block_copy.RunWrite(
b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// output: register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed<4>(
make_multi_index(MRepeat, MPerThread, NRepeat, NPerThread));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
block_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
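// Worked example (hypothetical sizes, not from this commit): with MPerThread = NPerThread = 4
// and 4x4 level-0/level-1 clusters, M1 = N1 = 64; a thread whose C tile begins at global
// (m, n) = (200, 70) writes below to
// (m0, m1, n0, n1) = (200 / 64, 200 % 64, 70 / 64, 70 % 64) = (3, 8, 1, 6).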
ThreadwiseDynamicTensorSliceTransfer_v1r2<
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
1,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
1>(c_m0_m1_n0_n1_thread_desc,
make_multi_index(0, 0, 0, 0),
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
}
}
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, IsEvenNumberKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>{});
}
};
} // namespace ck
#endif