Commit 70d06fa9 authored by Chao Liu

fixing useless instruction issue

parent 7733dd88
@@ -173,66 +173,57 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
 
         // GEMM
+#if 1
         using gridwise_gemm =
-            GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
+            GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
                                             Float,
                                             AccFloat,
                                             InMemoryDataOperation::Set,
                                             GemmMPerBlock,
                                             GemmNPerBlock,
                                             GemmKPerBlock,
                                             GemmMPerThread,
                                             GemmNPerThread,
                                             GemmKPerThread,
                                             GemmMLevel0Cluster,
                                             GemmNLevel0Cluster,
                                             GemmMLevel1Cluster,
                                             GemmNLevel1Cluster,
                                             GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
                                             GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
                                             Sequence<1, 0>,
                                             Sequence<1, 0>,
                                             0,
                                             GemmABlockTransferSrcScalarPerVector_GemmK,
                                             GemmABlockTransferDstScalarPerVector_GemmM,
                                             GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
                                             GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
                                             Sequence<0, 1>,
                                             Sequence<0, 1>,
                                             1,
                                             GemmBBlockTransferSrcScalarPerVector_GemmN,
                                             GemmBBlockTransferDstScalarPerVector_GemmN,
                                             Sequence<2, 3, 0, 1>,
                                             3,
                                             GemmCThreadTransferDstScalarPerVector_GemmN1>;
 
         const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
         const bool is_even_number_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
 
-        const auto kernel_even =
-            run_gridwise_operation<gridwise_gemm,
-                                   decltype(wei_gemmk_gemmm_global_desc),
-                                   const Float*,
-                                   decltype(in_gemmk_gemmn_global_desc),
-                                   const Float*,
-                                   decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
-                                   Float*,
-                                   integral_constant<bool, true>>;
-
-        const auto kernel_odd =
-            run_gridwise_operation<gridwise_gemm,
-                                   decltype(wei_gemmk_gemmm_global_desc),
-                                   const Float*,
-                                   decltype(in_gemmk_gemmn_global_desc),
-                                   const Float*,
-                                   decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
-                                   Float*,
-                                   integral_constant<bool, false>>;
-
         if(is_even_number_k_block_loop)
         {
-            launch_kernel(kernel_even,
+            const auto kernel =
+                run_gridwise_operation<gridwise_gemm,
+                                       decltype(wei_gemmk_gemmm_global_desc),
+                                       const Float*,
+                                       decltype(in_gemmk_gemmn_global_desc),
+                                       const Float*,
+                                       decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                       Float*,
+                                       integral_constant<bool, true>>;
+
+            launch_kernel(kernel,
                           dim3(GridSize),
                           dim3(BlockSize),
                           0,
@@ -247,7 +238,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         }
         else
         {
-            launch_kernel(kernel_odd,
+            const auto kernel =
+                run_gridwise_operation<gridwise_gemm,
+                                       decltype(wei_gemmk_gemmm_global_desc),
+                                       const Float*,
+                                       decltype(in_gemmk_gemmn_global_desc),
+                                       const Float*,
+                                       decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                       Float*,
+                                       integral_constant<bool, false>>;
+
+            launch_kernel(kernel,
                           dim3(GridSize),
                           dim3(BlockSize),
                           0,
@@ -260,6 +261,63 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                       p_out_global,
                       integral_constant<bool, false>{});
         }
+#else
+        using gridwise_gemm =
+            GridwiseDynamicGemm_km_kn_mn_v2<BlockSize,
+                                            Float,
+                                            AccFloat,
+                                            InMemoryDataOperation::Set,
+                                            GemmMPerBlock,
+                                            GemmNPerBlock,
+                                            GemmKPerBlock,
+                                            GemmMPerThread,
+                                            GemmNPerThread,
+                                            GemmKPerThread,
+                                            GemmMLevel0Cluster,
+                                            GemmNLevel0Cluster,
+                                            GemmMLevel1Cluster,
+                                            GemmNLevel1Cluster,
+                                            GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
+                                            GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
+                                            Sequence<1, 0>,
+                                            Sequence<1, 0>,
+                                            0,
+                                            GemmABlockTransferSrcScalarPerVector_GemmK,
+                                            GemmABlockTransferDstScalarPerVector_GemmM,
+                                            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
+                                            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
+                                            Sequence<0, 1>,
+                                            Sequence<0, 1>,
+                                            1,
+                                            GemmBBlockTransferSrcScalarPerVector_GemmN,
+                                            GemmBBlockTransferDstScalarPerVector_GemmN,
+                                            Sequence<2, 3, 0, 1>,
+                                            3,
+                                            GemmCThreadTransferDstScalarPerVector_GemmN1>;
+
+        const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
+
+        const auto kernel =
+            run_gridwise_operation<gridwise_gemm,
+                                   decltype(wei_gemmk_gemmm_global_desc),
+                                   const Float*,
+                                   decltype(in_gemmk_gemmn_global_desc),
+                                   const Float*,
+                                   decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                   Float*>;
+
+        launch_kernel(kernel,
+                      dim3(GridSize),
+                      dim3(BlockSize),
+                      0,
+                      0,
+                      wei_gemmk_gemmm_global_desc,
+                      p_wei_global,
+                      in_gemmk_gemmn_global_desc,
+                      p_in_global,
+                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                      p_out_global);
+#endif
     }
 };
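
Note on the even/odd split above: is_even_number_k_block_loop is only known at runtime, so the driver branches once on the host side and instantiates two kernels, each carrying the parity as an integral_constant<bool, ...> template argument; inside the kernel the parity is then a compile-time constant. A minimal sketch of the pattern (names are illustrative, not the library's API):

template <bool IsEvenNumberKBlockLoop>
void run_gemm_sketch()
{
    if constexpr(IsEvenNumberKBlockLoop)
    {
        // tail drains the last two K blocks (loop ends on the second LDS buffer)
    }
    else
    {
        // tail drains the last single K block (loop ends on the first LDS buffer)
    }
}

void dispatch(int GemmK, int GemmKPerBlock)
{
    // branch once at launch time; each branch uses a distinct template instantiation
    if((GemmK / GemmKPerBlock) % 2 == 0)
        run_gemm_sketch<true>();
    else
        run_gemm_sketch<false>();
}

The commit moves each run_gridwise_operation instantiation into the branch that uses it; both specializations still exist, the choice between them is simply made once, outside the kernel.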
...
@@ -42,7 +42,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct GridwiseDynamicGemm_km_kn_mn_v1r1
+struct GridwiseDynamicGemm_km_kn_mn_v1
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
@@ -90,11 +90,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
         const index_t N = b_k_n_global_desc.GetLength(I1);
 
         // divide block work by [M, N]
+#if 0
         const index_t m_block_work_num = M / MPerBlock;
         const index_t n_block_work_num = N / NPerBlock;
+#else
+        // Hack: this forces the result into an SGPR
+        const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
+        const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
+#endif
 
+#if 0
         const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
         const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
+#else
+        // Hack: this forces the result into an SGPR
+        const index_t m_block_work_id =
+            __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
+        const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
+#endif
 
         const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
         const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
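
The __builtin_amdgcn_readfirstlane hack above appears to be the heart of the "useless instruction" fix: the operands are uniform across the wavefront, so reading lane 0 is semantically the identity, but it pins the result to a scalar register (SGPR), letting the address arithmetic that consumes it run once on the scalar unit instead of redundantly in every vector lane. A self-contained illustration, valid only under the stated uniformity assumption:

// Identity for wave-uniform values, but steers the compiler to keep the result in an SGPR.
// Only safe when every lane computes the same value, as M / MPerBlock does here.
__device__ int uniform_div(int M, int MPerBlock)
{
    return __builtin_amdgcn_readfirstlane(M / MPerBlock);
}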
@@ -117,7 +130,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
         // A matrix blockwise copy
         auto a_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
+            BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
                                                      Float,
                                                      Float,
                                                      decltype(a_k_m_global_desc),
@@ -136,14 +149,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
                                                      AddressSpace::Lds,
                                                      InMemoryDataOperation::Set,
                                                      1,
-                                                     1>(a_k_m_global_desc,
-                                                        make_multi_index(0, m_block_data_on_global),
-                                                        a_k_m_block_desc,
-                                                        make_multi_index(0, 0));
+                                                     1,
+                                                     true,
+                                                     true>(
+                a_k_m_global_desc,
+                make_multi_index(0, m_block_data_on_global),
+                a_k_m_block_desc,
+                make_multi_index(0, 0));
 
         // B matrix blockwise copy
         auto b_block_copy =
-            BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
+            BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
                                                      Float,
                                                      Float,
                                                      decltype(b_k_n_global_desc),
@@ -162,10 +178,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
                                                      AddressSpace::Lds,
                                                      InMemoryDataOperation::Set,
                                                      1,
-                                                     1>(b_k_n_global_desc,
-                                                        make_multi_index(0, n_block_data_on_global),
-                                                        b_k_n_block_desc,
-                                                        make_multi_index(0, 0));
+                                                     1,
+#if 0
+                                                     true,
+#else
+                                                     false,
+#endif
+                                                     true>(
+                b_k_n_global_desc,
+                make_multi_index(0, n_block_data_on_global),
+                b_k_n_block_desc,
+                make_multi_index(0, 0));
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
@@ -230,12 +253,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
         threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+#if 0
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+#else
+        // HACK: fuse the threadwise copy's coordinate move-back with the move of the
+        // src slice window
+        constexpr auto b_block_slice_copy_step =
+            b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
+#endif
 
         // LDS double buffer: preload data into LDS
         {
-            a_block_copy.Run(a_k_m_global_desc, p_a_global, a_k_m_block_desc, p_a_block_double);
-            b_block_copy.Run(b_k_n_global_desc, p_b_global, b_k_n_block_desc, p_b_block_double);
+            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
+
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
         }
 
         // LDS double buffer: main body
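
Splitting the old monolithic Run (global memory straight to LDS) into RunRead plus RunWrite stages the tile through a per-thread register buffer, so the caller can issue the global loads, run the GEMM on the previously loaded buffer, and only then commit the new tile to LDS. A simplified stand-alone sketch of the idea (illustrative names, not the library's signatures):

__device__ void staged_copy(const float* __restrict__ p_global,
                            float* __restrict__ p_lds,
                            int global_offset,
                            float* p_acc)
{
    // RunRead: global -> register; the load can be in flight during the work below
    float r = p_global[global_offset];

    // stand-in for block_gemm.Run(...) on the previously written LDS buffer
    *p_acc += 1.0f;

    // RunWrite: register -> LDS, deferred until the compute has been issued
    p_lds[threadIdx.x] = r;
}

In the main body below, block_gemm.Run sits exactly between RunRead and RunWrite, hiding global-memory latency behind the math.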
@@ -262,16 +298,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
                 __syncthreads();
 
+                Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+                Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
                 // LDS double buffer: load next data from device mem
-                a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
-                b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
+                a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+                b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
 
                 // LDS double buffer: GEMM on current data
                 block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
-                a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
-                b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
+                a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
+                b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
             }
         }
 
@@ -284,16 +323,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
             __syncthreads();
 
+            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
             // LDS double buffer: load last data from device mem
-            a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
-            b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
 
             // LDS double buffer: GEMM on 2nd-last data
             block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
 
             // LDS double buffer: store last data to LDS
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_block_copy.RunWrite(
+                a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
+            b_block_copy.RunWrite(
+                b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
 
             __syncthreads();
@@ -411,7 +455,7 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct GridwiseDynamicGemm_km_kn_mn_v1r2
+struct GridwiseDynamicGemm_km_kn_mn_v2
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
@@ -437,18 +481,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         constexpr index_t b_block_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
 
-        return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
+        return (a_block_space_size + b_block_space_size) * sizeof(Float);
     }
 
-    template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
+    template <typename... ADesc, typename... BDesc, typename... CDesc>
     __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
                         const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
                         const Float* __restrict__ p_b_global,
                         const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
                         Float* __restrict__ p_c_global,
-                        Float* __restrict__ p_shared_block,
-                        integral_constant<bool, IsEvenNumberKBlockLoop>) const
+                        Float* __restrict__ p_shared_block) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -612,8 +655,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         constexpr index_t b_block_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
 
-        Float* p_a_block_double = p_shared_block;
-        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
+        Float* p_a_block = p_shared_block;
+        Float* p_b_block = p_shared_block + a_block_space_size;
 
         // register allocation for output
         AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
 
@@ -631,7 +674,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
             b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
 #endif
 
-        // LDS double buffer: preload data into LDS
+        // preload data into LDS
         {
             Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
             Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
@@ -639,89 +682,41 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
             a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
             b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
 
-            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
-            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
         }
 
-        // LDS double buffer: main body
-        for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
-            k_block_data_begin += 2 * KPerBlock)
+        // main body
+        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
+            k_block_data_begin += KPerBlock)
         {
-#pragma unroll
-            for(index_t iloop = 0; iloop < 2; ++iloop)
-            {
-                const bool even_loop = (iloop % 2 == 0);
-
-                Float* p_a_block_now =
-                    even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
-                Float* p_b_block_now =
-                    even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
-
-                Float* p_a_block_next =
-                    even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
-                Float* p_b_block_next =
-                    even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
-
-                a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
-
-                __syncthreads();
-
-                Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-                Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-                // LDS double buffer: load next data from device mem
-                a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-                b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-                // LDS double buffer: GEMM on current data
-                block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
-
-                // LDS double buffer: store next data to LDS
-                a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
-                b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
-            }
+            Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+            Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
+
+            a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
+            b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
+
+            // load next data from device mem
+            a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
+            b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
+
+            __syncthreads();
+
+            // GEMM on current data
+            block_gemm.Run(p_a_block, p_b_block, p_c_thread);
+
+            __syncthreads();
+
+            // store next data to LDS
+            a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
+            b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
         }
 
-        // LDS double buffer: tail
+        // tail
        {
-            if constexpr(IsEvenNumberKBlockLoop) // if there are 2 iterations left
-            {
-                a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
-
-                __syncthreads();
-
-                Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-                Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
-
-                // LDS double buffer: load last data from device mem
-                a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
-                b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
-
-                // LDS double buffer: GEMM on 2nd-last data
-                block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
-
-                // LDS double buffer: store last data to LDS
-                a_block_copy.RunWrite(
-                    a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
-                b_block_copy.RunWrite(
-                    b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
-
-                __syncthreads();
-
-                // LDS double buffer: GEMM on last data
-                block_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
-            }
-            else // if there is 1 iteration left
-            {
-                __syncthreads();
-
-                // LDS double buffer: GEMM on last data
-                block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
-            }
+            __syncthreads();
+
+            block_gemm.Run(p_a_block, p_b_block, p_c_thread);
         }
 
         // output: register to global memory
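
Compared with the v1 loop, this v2 body drops the LDS double buffer entirely: GetSharedMemoryNumberOfByte above loses its factor of 2, and a second __syncthreads() per iteration protects the single buffer instead. The first barrier makes the previous iteration's RunWrite visible before the tile is read; the second ensures every thread has finished reading before the tile is overwritten. A self-contained sketch of that schedule (assuming a simplified 256-thread block and made-up indexing, not the library's API):

__global__ void single_buffer_sketch(const float* p_in, float* p_out, int num_tiles)
{
    __shared__ float tile[256];
    float acc = 0.0f;

    // preload the first tile
    tile[threadIdx.x] = p_in[blockIdx.x * 256 + threadIdx.x];

    for(int t = 1; t < num_tiles; ++t)
    {
        // load the next tile into a register while the current one is consumed
        const float r = p_in[(blockIdx.x + t) * 256 + threadIdx.x];

        __syncthreads(); // (1) previous write to tile[] is visible to all threads

        acc += tile[(threadIdx.x + t) % 256]; // stand-in for block_gemm.Run(...)

        __syncthreads(); // (2) all reads of tile[] are done; safe to overwrite

        tile[threadIdx.x] = r;
    }

    __syncthreads();
    acc += tile[threadIdx.x]; // tail: GEMM on the last tile
    p_out[blockIdx.x * 256 + threadIdx.x] = acc;
}

The trade-off is less overlap between compute and LDS traffic in exchange for half the LDS footprint and no even/odd tail specialization.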
@@ -769,14 +764,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         }
     }
 
-    template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
+    template <typename... ADesc, typename... BDesc, typename... CDesc>
     __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
                         const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
                         const Float* __restrict__ p_b_global,
                         const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
-                        Float* __restrict__ p_c_global,
-                        integral_constant<bool, IsEvenNumberKBlockLoop>) const
+                        Float* __restrict__ p_c_global) const
     {
         constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
 
@@ -788,8 +782,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
             p_b_global,
             c_m0_m1_n0_n1_global_desc,
             p_c_global,
-            p_shared_block,
-            integral_constant<bool, IsEvenNumberKBlockLoop>{});
+            p_shared_block);
     }
 };
...
@@ -255,16 +255,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
         constexpr index_t Len0 = SliceLengths{}[0];
         constexpr index_t Len1 = SliceLengths{}[1];
 
-        bool forward_dim0 = true;
-        bool forward_dim1 = true;
-
 #pragma unroll
         for(index_t i0 = 0; i0 < Len0; ++i0)
         {
 #pragma unroll
             for(index_t i1 = 0; i1 < Len1; ++i1)
             {
-                // do work
+#if 1 // debug
+                // do work
                 transfer_data<SrcData,
                               1,
                               SrcAddressSpace,
@@ -282,10 +280,69 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                     coordinate_has_valid_offset_assuming_visible_index_is_valid(
                         dst_desc, dst_slice_origin_),
                     dst_desc.GetElementSpaceSize());
+#else
+                if constexpr(SrcAddressSpace == AddressSpace::Global &&
+                             DstAddressSpace == AddressSpace::Vgpr)
+                {
+                    if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                           dst_desc, dst_slice_origin_))
+                    {
+                        const SrcData tmp = amd_buffer_load<SrcData, 1>(
+                            p_src,
+                            src_slice_origin_.GetOffset(),
+                            coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                                src_desc, src_slice_origin_),
+                            src_desc.GetElementSpaceSize());
+
+                        const index_t dst_offset = dst_slice_origin_.GetOffset();
+
+                        p_dst[dst_offset] = tmp;
+                    }
+                }
+                else if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
+                                  DstAddressSpace == AddressSpace::Global)
+                {
+                    const SrcData zeros = 0;
+
+                    const bool src_valid =
+                        coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                            src_desc, src_slice_origin_);
+
+                    const bool dst_valid =
+                        coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                            dst_desc, dst_slice_origin_);
+
+                    amd_buffer_store<SrcData, 1>(
+                        src_valid ? &(p_src[src_slice_origin_.GetOffset()]) : &zeros,
+                        p_dst,
+                        dst_slice_origin_.GetOffset(),
+                        dst_valid,
+                        dst_desc.GetElementSpaceSize());
+                }
+                else
+                {
+                    if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                           dst_desc, dst_slice_origin_))
+                    {
+                        if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
+                               src_desc, src_slice_origin_))
+                        {
+                            p_dst[dst_slice_origin_.GetOffset()] =
+                                p_src[src_slice_origin_.GetOffset()];
+                        }
+                        else
+                        {
+                            p_dst[dst_slice_origin_.GetOffset()] = 0;
+                        }
+                    }
+                }
+#endif
 
                 // move dim1 iterator
                 if(i1 < Len1 - 1)
                 {
+                    bool forward_dim1 = (i0 % 2 == 0);
+
                     if(forward_dim1)
                     {
                         move_dynamic_tensor_coordinate(
@@ -303,22 +360,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
                     }
                 }
 
-                // switch dim1 iteration direction
-                forward_dim1 = !forward_dim1;
-
                 // move dim0 iterator
                 if(i0 < Len0 - 1)
                 {
-                    if(forward_dim0)
-                    {
-                        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
-                        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
-                    }
-                    else
-                    {
-                        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_m1_0);
-                        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_m1_0);
-                    }
+                    move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
+                    move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
                 }
             }
         }
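
The traversal is still serpentine in dim1, but the direction is now recomputed from the row index (forward_dim1 = (i0 % 2 == 0)) instead of being carried in a mutable flag, and the never-taken backward-dim0 branch is removed (forward_dim0 was initialized to true and never flipped). With the loops fully unrolled, the direction becomes a compile-time constant per iteration, which plausibly eliminates more of the useless instructions named in the commit title. The visiting order, as a small host-side illustration:

#include <cstdio>

int main()
{
    constexpr int Len0 = 3, Len1 = 4;

    for(int i0 = 0; i0 < Len0; ++i0)
    {
        const bool forward_dim1 = (i0 % 2 == 0);

        for(int i1 = 0; i1 < Len1; ++i1)
        {
            // column actually visited: rows alternate direction, so row ends stay adjacent
            const int j1 = forward_dim1 ? i1 : Len1 - 1 - i1;
            std::printf("(%d, %d) ", i0, j1);
        }
    }
    std::printf("\n"); // (0,0) (0,1) (0,2) (0,3) (1,3) (1,2) (1,1) (1,0) (2,0) ...
}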
...
@@ -185,25 +185,26 @@ __device__ void transfer_data(const T* p_src,
                   "wrong! InMemoryDataOperation not supported!");
 
     // keep it simple, don't use static_if here, otherwise compiler will do weird things
-    if(SrcDataStride == 1 && DstDataStride == 1)
+    if constexpr(SrcDataStride == 1 && DstDataStride == 1)
     {
-        // TODO: use static_if::ElseIf
-        static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+        if constexpr(DstInMemOp == InMemoryDataOperation::Set)
+        {
             SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                 p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
-        });
-
-        static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
+        }
+        else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
+        {
             AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                 p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
-        });
+        }
     }
     else
     {
+#pragma unroll
         for(index_t i = 0; i < DataPerAccess; ++i)
         {
-            // TODO: use static_if::ElseIf
-            static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
+            if constexpr(DstInMemOp == InMemoryDataOperation::Set)
+            {
                 SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                     p_src,
                     src_offset + i * SrcDataStride,
@@ -213,9 +214,9 @@ __device__ void transfer_data(const T* p_src,
                     dst_offset + i * DstDataStride,
                     dst_valid,
                     dst_range);
-            });
-
-            static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
+            }
+            else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
+            {
                 AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
                     p_src,
                     src_offset + i * SrcDataStride,
@@ -225,7 +226,7 @@ __device__ void transfer_data(const T* p_src,
                     dst_offset + i * DstDataStride,
                     dst_valid,
                     dst_range);
-            });
+            }
        }
    }
}
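
Replacing the static_if<...>{}([&](auto) { ... }); emulation with C++17 if constexpr keeps only the taken branch at instantiation time, without the generic-lambda indirection the old comment warned the compiler "will do weird things" with. A minimal sketch of the same pattern (illustrative types, not the library's):

enum class Op { Set, AtomicAdd };

template <Op op>
void store(float* p_dst, float v)
{
    if constexpr(op == Op::Set)
    {
        *p_dst = v; // only this branch is instantiated for store<Op::Set>
    }
    else if constexpr(op == Op::AtomicAdd)
    {
        // the real code dispatches to AtomicAddData (an atomic add on device);
        // a plain add keeps this host-compilable sketch self-contained
        *p_dst += v;
    }
}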
...