Commit 70d06fa9 authored by Chao Liu's avatar Chao Liu

fixing useless instruction issue

parent 7733dd88
......@@ -173,66 +173,57 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// GEMM
#if 1
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool is_even_number_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
const auto kernel_even =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>>;
const auto kernel_odd =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>>;
if(is_even_number_k_block_loop)
{
launch_kernel(kernel_even,
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
......@@ -247,7 +238,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
}
else
{
launch_kernel(kernel_odd,
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
......@@ -260,6 +261,63 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_out_global,
integral_constant<bool, false>{});
}
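The even/odd branch above selects at runtime between two instantiations of the same GEMM kernel, distinguished only by a compile-time integral_constant<bool, ...> flag. A minimal sketch of that dispatch pattern, with a hypothetical kernel name and a plain bool template parameter standing in for the integral_constant argument:

```cpp
#include <hip/hip_runtime.h>

// Hypothetical kernel: the bool template parameter plays the role of the
// IsEvenNumberKBlockLoop flag above and selects a code path at compile time.
template <bool IsEvenNumberKBlockLoop>
__global__ void gemm_kernel_sketch(float* p_c)
{
    if constexpr(IsEvenNumberKBlockLoop)
        p_c[threadIdx.x] = 2.0f; // stand-in for the even (2-step tail) path
    else
        p_c[threadIdx.x] = 1.0f; // stand-in for the odd (1-step tail) path
}

// Runtime branch picks one of the two compile-time specializations,
// mirroring the is_even_number_k_block_loop dispatch above.
void launch_sketch(float* p_c, int grid_size, int block_size, int gemm_k, int k_per_block)
{
    const bool is_even = (gemm_k / k_per_block) % 2 == 0;

    if(is_even)
        hipLaunchKernelGGL(gemm_kernel_sketch<true>, dim3(grid_size), dim3(block_size), 0, 0, p_c);
    else
        hipLaunchKernelGGL(gemm_kernel_sketch<false>, dim3(grid_size), dim3(block_size), 0, 0, p_c);
}
```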
#else
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v2<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global);
#endif
}
};
......
......@@ -42,7 +42,7 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseDynamicGemm_km_kn_mn_v1r1
struct GridwiseDynamicGemm_km_kn_mn_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......@@ -90,11 +90,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
const index_t N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
#if 0
const index_t m_block_work_num = M / MPerBlock;
const index_t n_block_work_num = N / NPerBlock;
#else
// Hack: this forces the result into an SGPR
const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock);
const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock);
#endif
#if 0
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#else
// Hack: this forces the result into an SGPR
const index_t m_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num);
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
#endif
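__builtin_amdgcn_readfirstlane broadcasts lane 0's value across the wavefront, which lets the compiler keep the result in a scalar register (SGPR) instead of a per-lane VGPR; this is only valid because the block-work indices are uniform across the wavefront. A minimal standalone sketch of the pattern, under that uniformity assumption (hypothetical kernel, not the gridwise GEMM itself):

```cpp
#include <hip/hip_runtime.h>

// block_work_num and the work ids depend only on blockIdx and compile-time
// constants, so they are wavefront-uniform: readfirstlane does not change
// their value, it only tells the compiler the value is scalar.
__global__ void sgpr_hint_sketch(int* p_out, int n, int n_per_block)
{
    const int n_block_work_num = __builtin_amdgcn_readfirstlane(n / n_per_block);
    const int m_block_work_id  = __builtin_amdgcn_readfirstlane(blockIdx.x / n_block_work_num);
    const int n_block_work_id  = blockIdx.x - m_block_work_id * n_block_work_num;

    if(threadIdx.x == 0)
        p_out[blockIdx.x] = m_block_work_id * 1000 + n_block_work_id;
}
```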
const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
......@@ -117,7 +130,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
// A matrix blockwise copy
auto a_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(a_k_m_global_desc),
......@@ -136,14 +149,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1>(a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
1,
true,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
BlockwiseDynamicTensorSliceTransfer_v2r3<BlockSize,
Float,
Float,
decltype(b_k_n_global_desc),
......@@ -162,10 +178,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1>(b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
1,
#if 0
true,
#else
false,
#endif
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......@@ -230,12 +253,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#if 0
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
#else
// HACK: fuse the threadwise copy's move-back coordinate step with the src slice window move
constexpr auto b_block_slice_copy_step =
b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
#endif
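The fused step above adds the coordinate's step-back offset to the KPerBlock window move, so only one coordinate update is issued per iteration instead of two. A toy 1-D illustration of the idea with plain integer offsets (CopyCursorSketch and its members are hypothetical; the real code moves multi-dimensional tensor coordinates):

```cpp
// Toy 1-D illustration of the fused step. step_back is whatever offset the
// threadwise read accumulated while sweeping its window; k_per_block advances
// the window itself. Applying their sum once replaces two separate moves.
struct CopyCursorSketch
{
    int offset    = 0; // current read position
    int step_back = 0; // distance accumulated inside the window

    void read_window(int window_len)
    {
        offset += window_len;     // sweep the window
        step_back = -window_len;  // remember how far to step back afterwards
    }

    void move_window_fused(int k_per_block)
    {
        // one move instead of "offset += step_back; offset += k_per_block;"
        offset += step_back + k_per_block;
        step_back = 0;
    }
};
```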
// LDS double buffer: preload data into LDS
{
a_block_copy.Run(a_k_m_global_desc, p_a_global, a_k_m_block_desc, p_a_block_double);
b_block_copy.Run(b_k_n_global_desc, p_b_global, b_k_n_block_desc, p_b_block_double);
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
}
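The preload now stages data explicitly: RunRead gathers the global-memory slice into a per-thread register buffer, and RunWrite later flushes that buffer to LDS, so the two phases can be placed on either side of the compute. A minimal sketch of the staging pattern with a hypothetical slice size and plain loads/stores (not the library's transfer classes):

```cpp
#include <hip/hip_runtime.h>

// Hypothetical per-thread slice of 4 elements; the real slice lengths come
// from the BlockTransferThreadSliceLengths parameters above.
constexpr int ThreadSliceSize = 4;

__device__ void run_read_sketch(const float* p_global, int global_offset,
                                float (&thread_buf)[ThreadSliceSize])
{
    // Phase 1: global memory -> registers (VGPRs).
    for(int i = 0; i < ThreadSliceSize; ++i)
        thread_buf[i] = p_global[global_offset + i];
}

__device__ void run_write_sketch(float* p_lds, int lds_offset,
                                 const float (&thread_buf)[ThreadSliceSize])
{
    // Phase 2: registers -> LDS; can be deferred until after a GEMM step.
    for(int i = 0; i < ThreadSliceSize; ++i)
        p_lds[lds_offset + i] = thread_buf[i];
}
```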
// LDS double buffer: main body
......@@ -262,16 +298,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
}
}
......@@ -284,16 +323,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r1
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
a_block_copy.RunWrite(
a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
b_block_copy.RunWrite(
b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
__syncthreads();
......@@ -411,7 +455,7 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseDynamicGemm_km_kn_mn_v1r2
struct GridwiseDynamicGemm_km_kn_mn_v2
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......@@ -437,18 +481,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
return (a_block_space_size + b_block_space_size) * sizeof(Float);
}
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
template <typename... ADesc, typename... BDesc, typename... CDesc>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>) const
Float* __restrict__ p_shared_block) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -612,8 +655,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
Float* p_a_block = p_shared_block;
Float* p_b_block = p_shared_block + a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
......@@ -631,7 +674,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
b_block_copy.threadwise_read_.GetCoordinateStepBack() + make_multi_index(KPerBlock, 0);
#endif
// LDS double buffer: preload data into LDS
// preload data into LDS
{
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
......@@ -639,89 +682,41 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double, p_b_thread_buffer);
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
}
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
// main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
k_block_data_begin += KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
// load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
__syncthreads();
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
// GEMM on current data
block_gemm.Run(p_a_block, p_b_block, p_c_thread);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
__syncthreads();
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next, p_b_thread_buffer);
}
// store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block, p_a_thread_buffer);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block, p_b_thread_buffer);
}
// LDS double buffer: tail
// tail
{
if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left
{
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
__syncthreads();
Float p_a_thread_buffer[a_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
Float p_b_thread_buffer[b_block_copy.thread_buffer_desc_.GetElementSpaceSize()];
// LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global, p_a_thread_buffer);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_block_copy.RunWrite(
a_k_m_block_desc, p_a_block_double + a_block_space_size, p_a_thread_buffer);
b_block_copy.RunWrite(
b_k_n_block_desc, p_b_block_double + b_block_space_size, p_b_thread_buffer);
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
block_gemm.Run(p_a_block, p_b_block, p_c_thread);
}
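With the double buffer removed in the v2 path, each iteration prefetches the next K-slab into registers while the current slab is consumed from the single LDS buffer, and only overwrites LDS after a barrier. A self-contained sketch of that loop structure, using a reduction as a stand-in for the block GEMM and hypothetical tile sizes:

```cpp
#include <hip/hip_runtime.h>

// Hypothetical sizes; the kernel above derives these from KPerBlock and the
// block transfer thread slice lengths. Assumes blockDim.x * ThreadElems == TileSize
// and p_a_global holds num_tiles * TileSize elements.
constexpr int TileSize    = 256; // elements of one K-slab held in LDS
constexpr int ThreadElems = 4;   // elements each thread stages in registers

__global__ void single_lds_buffer_pipeline_sketch(const float* p_a_global,
                                                  float* p_c_global,
                                                  int num_tiles)
{
    __shared__ float lds_tile[TileSize]; // one buffer, not two
    float reg_buf[ThreadElems];          // per-thread staging registers
    float acc = 0.0f;
    const int tid = threadIdx.x;

    // preload tile 0: global -> regs -> LDS
    for(int i = 0; i < ThreadElems; ++i)
        reg_buf[i] = p_a_global[tid * ThreadElems + i];
    for(int i = 0; i < ThreadElems; ++i)
        lds_tile[tid * ThreadElems + i] = reg_buf[i];

    for(int t = 0; t < num_tiles - 1; ++t)
    {
        // prefetch tile t+1 into registers while tile t sits in LDS
        for(int i = 0; i < ThreadElems; ++i)
            reg_buf[i] = p_a_global[(t + 1) * TileSize + tid * ThreadElems + i];

        __syncthreads(); // tile t is fully written to LDS

        for(int i = 0; i < ThreadElems; ++i)
            acc += lds_tile[tid * ThreadElems + i]; // stand-in for the block GEMM

        __syncthreads(); // all reads of tile t have finished

        for(int i = 0; i < ThreadElems; ++i)
            lds_tile[tid * ThreadElems + i] = reg_buf[i]; // overwrite with tile t+1
    }

    __syncthreads(); // tail: last tile is in LDS
    for(int i = 0; i < ThreadElems; ++i)
        acc += lds_tile[tid * ThreadElems + i];

    p_c_global[blockIdx.x * blockDim.x + tid] = acc;
}
```

The trade-off sketched here is the one the v2 path appears to make: half the LDS footprint and a simpler loop, at the cost of a second barrier per iteration before the LDS buffer can be reused.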
// output: register to global memory
......@@ -769,14 +764,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
}
}
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
template <typename... ADesc, typename... BDesc, typename... CDesc>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, IsEvenNumberKBlockLoop>) const
Float* __restrict__ p_c_global) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
......@@ -788,8 +782,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>{});
p_shared_block);
}
};
......
......@@ -255,16 +255,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
constexpr index_t Len0 = SliceLengths{}[0];
constexpr index_t Len1 = SliceLengths{}[1];
bool forward_dim0 = true;
bool forward_dim1 = true;
#pragma unroll
for(index_t i0 = 0; i0 < Len0; ++i0)
{
#pragma unroll
for(index_t i1 = 0; i1 < Len1; ++i1)
{
// do work
#if 1 // debug
// do work
transfer_data<SrcData,
1,
SrcAddressSpace,
......@@ -282,10 +280,69 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_),
dst_desc.GetElementSpaceSize());
#else
if constexpr(SrcAddressSpace == AddressSpace::Global &&
DstAddressSpace == AddressSpace::Vgpr)
{
if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_))
{
const SrcData tmp = amd_buffer_load<SrcData, 1>(
p_src,
src_slice_origin_.GetOffset(),
coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_),
src_desc.GetElementSpaceSize());
const index_t dst_offset = dst_slice_origin_.GetOffset();
p_dst[dst_offset] = tmp;
}
}
else if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
DstAddressSpace == AddressSpace::Global)
{
const SrcData zeros = 0;
const bool src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_);
const bool dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_);
amd_buffer_store<SrcData, 1>(
src_valid ? &(p_src[src_slice_origin_.GetOffset()]) : &zeros,
p_dst,
dst_slice_origin_.GetOffset(),
dst_valid,
dst_desc.GetElementSpaceSize());
}
else
{
if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
dst_desc, dst_slice_origin_))
{
if(coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_))
{
p_dst[dst_slice_origin_.GetOffset()] =
p_src[src_slice_origin_.GetOffset()];
}
else
{
p_dst[dst_slice_origin_.GetOffset()] = 0;
}
}
}
#endif
// move dim1 iterator
if(i1 < Len1 - 1)
{
bool forward_dim1 = (i0 % 2 == 0);
if(forward_dim1)
{
move_dynamic_tensor_coordinate(
......@@ -303,22 +360,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
}
}
// switch dim1 iteration direction
forward_dim1 = !forward_dim1;
// move dim0 iterator
if(i0 < Len0 - 1)
{
if(forward_dim0)
{
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
}
else
{
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_m1_0);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_m1_0);
}
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_, src_step_p1_0);
move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_, dst_step_p1_0);
}
}
}
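The dim1 iterator now snakes: even i0 rows walk dim1 forward, odd rows walk it backward, and dim0 always steps forward, so consecutive accesses stay coordinate-neighbours and no per-row reset is needed. A minimal host-side sketch of that serpentine visit order (serpentine_visit is a hypothetical helper, not the tensor-coordinate machinery):

```cpp
#include <cstdio>

// Visit a Len0 x Len1 slice in the serpentine order used above: even i0 rows
// go left-to-right in dim1, odd i0 rows go right-to-left, and dim0 always
// advances by +1, so consecutive visits are always coordinate neighbours.
template <int Len0, int Len1, typename F>
void serpentine_visit(F&& visit)
{
    for(int i0 = 0; i0 < Len0; ++i0)
    {
        const bool forward_dim1 = (i0 % 2 == 0);
        for(int j = 0; j < Len1; ++j)
        {
            const int i1 = forward_dim1 ? j : (Len1 - 1 - j);
            visit(i0, i1);
        }
    }
}

int main()
{
    // prints (0,0) (0,1) (0,2) (1,2) (1,1) (1,0) (2,0) (2,1) (2,2)
    serpentine_visit<3, 3>([](int i0, int i1) { printf("(%d,%d) ", i0, i1); });
    printf("\n");
}
```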
......
......@@ -185,25 +185,26 @@ __device__ void transfer_data(const T* p_src,
"wrong! InMemoryDataOperation not supported!");
// keep it simple: don't use static_if here, otherwise the compiler does weird things
if(SrcDataStride == 1 && DstDataStride == 1)
if constexpr(SrcDataStride == 1 && DstDataStride == 1)
{
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
if constexpr(DstInMemOp == InMemoryDataOperation::Set)
{
SetData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
});
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
}
else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
{
AtomicAddData<T, DataPerAccess>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range);
});
}
}
else
{
#pragma unroll
for(index_t i = 0; i < DataPerAccess; ++i)
{
// TODO: use static_if::ElseIf
static_if<DstInMemOp == InMemoryDataOperation::Set>{}([&](auto) {
if constexpr(DstInMemOp == InMemoryDataOperation::Set)
{
SetData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src,
src_offset + i * SrcDataStride,
......@@ -213,9 +214,9 @@ __device__ void transfer_data(const T* p_src,
dst_offset + i * DstDataStride,
dst_valid,
dst_range);
});
static_if<DstInMemOp == InMemoryDataOperation::AtomicAdd>{}([&](auto) {
}
else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd)
{
AtomicAddData<T, 1>{}.template Run<SrcAddressSpace, DstAddressSpace>(
p_src,
src_offset + i * SrcDataStride,
......@@ -225,7 +226,7 @@ __device__ void transfer_data(const T* p_src,
dst_offset + i * DstDataStride,
dst_valid,
dst_range);
});
}
}
}
}
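The static_if lambdas are replaced with C++17 if constexpr, which discards the untaken branch at instantiation time without the functor indirection. A minimal sketch of the resulting shape, with a hypothetical InMemOp enum standing in for InMemoryDataOperation:

```cpp
#include <hip/hip_runtime.h>

// Hypothetical stand-in for InMemoryDataOperation.
enum class InMemOp { Set, AtomicAdd };

// With if constexpr, only the selected branch is instantiated, so the
// AtomicAdd path contributes no code when Op is Set -- the same effect the
// static_if lambdas achieved, with less indirection.
template <InMemOp Op>
__device__ void store_sketch(float* p_dst, float value)
{
    if constexpr(Op == InMemOp::Set)
    {
        *p_dst = value;
    }
    else if constexpr(Op == InMemOp::AtomicAdd)
    {
        atomicAdd(p_dst, value);
    }
}

__global__ void store_demo(float* p_out)
{
    store_sketch<InMemOp::Set>(p_out + threadIdx.x, 1.0f);
    store_sketch<InMemOp::AtomicAdd>(p_out, 1.0f);
}
```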
......