Commit 77c81617 authored by Chao Liu

improving index calculation

parent f2f35201
@@ -34,7 +34,7 @@ template <index_t BlockSize,
           index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
           index_t GemmBBlockTransferDstScalarPerVector_GemmN,
           index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
-struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
+struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
 {
     template <typename... Wei, typename... In, typename... Out>
     __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
@@ -96,18 +96,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         // input tensor
         const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
-            transform_dynamic_tensor_descriptor(
-                in_n_c_hi_wi_global_desc,
-                make_tuple(DynamicPassThrough{N},
-                           DynamicPassThrough{C},
-                           DynamicLeftPad{Hi, InLeftPadH},
-                           DynamicLeftPad{Wi, InLeftPadW}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})),
+            in_n_c_hi_wi_global_desc,
             make_tuple(DynamicPassThrough{N},
                        DynamicPassThrough{C},
-                       DynamicRightPad{Hi + InLeftPadH, InRightPadH},
-                       DynamicRightPad{Wi + InLeftPadW, InRightPadW}),
+                       DynamicPad{Hi, InLeftPadH, InRightPadH},
+                       DynamicPad{Wi, InLeftPadW, InRightPadW}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
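Note: the padded input descriptor is now built with one `DynamicPad` per spatial dimension instead of chaining `DynamicLeftPad` and `DynamicRightPad` through two nested transforms, so one fewer transform stage participates in every coordinate calculation. A minimal standalone sketch of what such a fused pad transform has to do (an illustration, not the library's `DynamicPad` implementation):

```cpp
// Standalone sketch of a fused left/right pad transform (illustration only).
#include <cassert>
#include <cstdint>

struct PadTransform
{
    int64_t lower_length; // e.g. Hi or Wi
    int64_t left_pad;     // e.g. InLeftPadH
    int64_t right_pad;    // e.g. InRightPadH

    // padded (upper) length seen by the next transform stage
    int64_t GetUpperLength() const { return left_pad + lower_length + right_pad; }

    // map a padded index back to the unpadded tensor
    int64_t CalculateLowerIndex(int64_t idx_upper) const { return idx_upper - left_pad; }

    // an index is valid only if it lands inside the original extent
    bool IsValid(int64_t idx_upper) const
    {
        const int64_t idx_lower = CalculateLowerIndex(idx_upper);
        return idx_lower >= 0 && idx_lower < lower_length;
    }
};

int main()
{
    PadTransform pad_h{/*Hi=*/28, /*InLeftPadH=*/1, /*InRightPadH=*/1};
    assert(pad_h.GetUpperLength() == 30);
    assert(!pad_h.IsValid(0));  // lands in the left padding
    assert(pad_h.IsValid(1));   // maps to row 0 of the input
    assert(!pad_h.IsValid(29)); // lands in the right padding
    return 0;
}
```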
@@ -164,6 +157,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         const index_t GemmM0 = GemmM / GemmM1;
         const index_t GemmN0 = GemmN / GemmN1;

+#if 1 // debug
         const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
             transform_dynamic_tensor_descriptor(
                 out_gemmm_gemmn_global_desc,
@@ -171,6 +165,16 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                            DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+#else
+        const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
+            transform_dynamic_tensor_descriptor(
+                out_gemmm_gemmn_global_desc,
+                make_tuple(
+                    HackSemiDynamicUnMerge<3, Sequence<GemmM1>>{make_multi_index(1, GemmM0)},
+                    HackSemiDynamicUnMerge<3, Sequence<GemmN1>>{make_multi_index(1, GemmN0)}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+#endif

         // GEMM
         using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
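Note: the `#if 1 // debug` branch keeps the fully dynamic `DynamicUnMerge`, which splits the flat GemmM and GemmN coordinates into (GemmM0, GemmM1) and (GemmN0, GemmN1); the disabled `#else` branch sketches a semi-static `HackSemiDynamicUnMerge` variant. A plain-C++ illustration of the index split itself (not the library's transform):

```cpp
// Split a flat GEMM index m into (m0, m1) with m = m0 * M1 + m1.
#include <array>
#include <cassert>
#include <cstdint>

std::array<int64_t, 2> unmerge(int64_t m, int64_t M1)
{
    return {m / M1, m % M1};
}

int main()
{
    const int64_t GemmM1 = 4; // e.g. GemmMPerThread
    const auto id = unmerge(/*flat GemmM index*/ 137, GemmM1);
    assert(id[0] == 34 && id[1] == 1); // 137 == 34 * 4 + 1
    return 0;
}
```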
......
@@ -96,6 +96,15 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }

+    __device__ void RunRead_hack(const SrcDesc& src_desc, const SrcData* p_src)
+    {
+        if(BlockSize == thread_cluster_desc_.GetElementSize() or
+           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        {
+            threadwise_transfer_.RunRead_hack(src_desc, p_src);
+        }
+    }
+
     __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
@@ -114,6 +123,15 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }

+    __device__ void MoveSrcSliceWindow_hack(const SrcDesc& src_desc, const Index& step)
+    {
+        if(BlockSize == thread_cluster_desc_.GetElementSize() or
+           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        {
+            threadwise_transfer_.MoveSrcSliceWindow_hack(src_desc, step);
+        }
+    }
+
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
@@ -149,146 +167,5 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
     ThreadwiseTransfer threadwise_transfer_;
 };
-// this version does following things to avoid scratch memory issue
-// 1. Use StaticallyIndexedArray instead of C array for thread buffer
-// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
-// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
-template <index_t BlockSize,
-          InMemoryDataOperation DstInMemOp,
-          typename BlockSliceLengths,
-          typename ThreadSliceLengths,
-          typename ThreadClusterLengths,
-          typename ThreadClusterArrangeOrder,
-          typename SrcData,
-          typename DstData,
-          typename SrcDesc,
-          typename DstDesc,
-          typename SrcDimAccessOrder,
-          typename DstDimAccessOrder,
-          index_t SrcVectorDim,
-          index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
-          index_t SrcScalarStrideInVector,
-          index_t DstScalarStrideInVector,
-          index_t ThreadTransferSrcResetCoordinateAfterRun,
-          index_t ThreadTransferDstResetCoordinateAfterRun>
-struct BlockwiseDynamicTensorSliceTransfer_v4_hack
-{
-    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
-
-    using Index = MultiIndex<nDim>;
-
-    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4_hack(
-        const SrcDesc& src_desc,
-        const Index& src_block_slice_origin,
-        const DstDesc& dst_desc,
-        const Index& dst_block_slice_origin)
-        : threadwise_transfer_(
-              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
-    {
-        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
-                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
-                          nDim == ThreadClusterLengths::Size() &&
-                          nDim == ThreadClusterArrangeOrder::Size() &&
-                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
-                      "wrong! nDim not consistent");
-
-        static_assert(
-            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
-            "wrong! threads should be mapped to cover entire slicing window");
-
-        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
-                      "wrong! BlockSize too small");
-
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            const auto thread_cluster_id =
-                thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-
-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
-
-            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
-                                                   src_block_slice_origin + thread_data_id_begin);
-            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
-                                                   dst_block_slice_origin + thread_data_id_begin);
-        }
-    }
-
-    __device__ static constexpr auto CalculateThreadDataBegin()
-    {
-        const auto thread_cluster_id =
-            thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id());
-
-        return thread_cluster_id * ThreadSliceLengths{};
-    }
-
-    __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.RunRead(src_desc, p_src);
-        }
-    }
-
-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
-        }
-    }
-
-    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
-        }
-    }
-
-    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
-    {
-        if(BlockSize == thread_cluster_desc_.GetElementSize() or
-           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
-        {
-            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
-        }
-    }
-
-    static constexpr auto thread_cluster_desc_ =
-        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
-    using ThreadwiseTransfer =
-        ThreadwiseDynamicTensorSliceTransfer_v3_hack<ThreadSliceLengths,
-                                                     DstInMemOp,
-                                                     SrcData,
-                                                     DstData,
-                                                     SrcDesc,
-                                                     DstDesc,
-                                                     SrcDimAccessOrder,
-                                                     DstDimAccessOrder,
-                                                     SrcVectorDim,
-                                                     DstVectorDim,
-                                                     SrcScalarPerVector,
-                                                     DstScalarPerVector,
-                                                     SrcScalarStrideInVector,
-                                                     DstScalarStrideInVector,
-                                                     SrcAddressSpace,
-                                                     DstAddressSpace,
-                                                     ThreadTransferSrcResetCoordinateAfterRun,
-                                                     ThreadTransferDstResetCoordinateAfterRun>;
-
-    ThreadwiseTransfer threadwise_transfer_;
-};
-
 } // namespace ck
 #endif
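Note: both the removed `_hack` struct and the retained `BlockwiseDynamicTensorSliceTransfer_v4` derive each thread's slice origin as `thread_cluster_id * ThreadSliceLengths`, with the cluster index decomposed from the flat thread id (see the constructor above). A standalone sketch of that origin calculation, assuming a plain row-major thread arrangement (the library's `ThreadClusterArrangeOrder` generalizes this):

```cpp
// Illustration only: map a flat thread id to its data origin in the block slice.
#include <array>
#include <cassert>
#include <cstdint>

using Index2 = std::array<int64_t, 2>;

Index2 thread_data_origin(int64_t tid, Index2 cluster_lengths, Index2 slice_lengths)
{
    // row-major decomposition of the flat thread id into a 2D cluster index
    const Index2 cluster_id = {tid / cluster_lengths[1], tid % cluster_lengths[1]};
    // each cluster coordinate owns ThreadSliceLengths elements per dimension
    return {cluster_id[0] * slice_lengths[0], cluster_id[1] * slice_lengths[1]};
}

int main()
{
    // thread 70 of a 4x64 cluster copying a 2x2 slice per thread starts at (2, 12):
    // cluster index (1, 6), scaled by the per-thread slice (2, 2)
    const auto origin = thread_data_origin(70, {4, 64}, {2, 2});
    assert(origin[0] == 2 && origin[1] == 12);
    return 0;
}
```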
@@ -166,28 +166,28 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4_hack<BlockSize,
+            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
                                                    Sequence<KPerBlock, NPerBlock>,
                                                    BBlockTransferThreadSliceLengths_K_N,
                                                    BBlockTransferThreadClusterLengths_K_N,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    Float,
                                                    Float,
                                                    decltype(b_k_n_global_desc),
                                                    decltype(b_k_n_block_desc),
                                                    BBlockTransferSrcAccessOrder,
                                                    Sequence<0, 1>,
                                                    BBlockTransferSrcVectorDim,
                                                    1,
                                                    BBlockTransferSrcScalarPerVector,
                                                    BBlockTransferDstScalarPerVector_N,
                                                    AddressSpace::Global,
                                                    AddressSpace::Lds,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
                                                    true>(
                 b_k_n_global_desc,
                 make_multi_index(0, n_block_data_on_global),
                 b_k_n_block_desc,
@@ -258,16 +258,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

-#if 1
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
+            b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global);

             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
             b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
         }
-#endif

         if constexpr(HasMainKBlockLoop)
         {
@@ -285,13 +283,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                 {
                     // even iteration
                     a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                    b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
+                    b_blockwise_copy.MoveSrcSliceWindow_hack(b_k_n_global_desc,
+                                                             b_block_slice_copy_step);

                     __syncthreads();

                     // LDS doubel buffer: load next data from device mem
                     a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
-                    b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
+                    b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global);

                     // LDS double buffer: GEMM on current data
                     blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
@@ -302,13 +301,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                     // odd iteration
                     a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                    b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
+                    b_blockwise_copy.MoveSrcSliceWindow_hack(b_k_n_global_desc,
+                                                             b_block_slice_copy_step);

                     __syncthreads();

                     // LDS doubel buffer: load next data from device mem
                     a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
-                    b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
+                    b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global);

                     // LDS double buffer: GEMM on current data
                     blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
@@ -326,13 +326,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
             if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
             {
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow_hack(b_k_n_global_desc, b_block_slice_copy_step);

                 __syncthreads();

                 // LDS double buffer: load last data from device mem
                 a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
-                b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
+                b_blockwise_copy.RunRead_hack(b_k_n_global_desc, p_b_global);

                 // LDS double buffer: GEMM on 2nd-last data
                 blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
@@ -384,8 +384,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                                                  Float,
                                                  decltype(c_m0_m1_n0_n1_thread_desc),
                                                  decltype(c_m0_m1_n0_n1_global_desc),
+#if 1 // debug
                                                  Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
                                                  CThreadTransferSrcDstAccessOrder,
+#else
+                                                 Sequence<1, 1, 2, 4>,
+                                                 Sequence<0, 1, 2, 3>,
+#endif
                                                  CThreadTransferSrcDstVectorDim,
                                                  1,
                                                  CThreadTransferDstScalarPerVector,
@@ -402,7 +407,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                                      m_thread_data_on_global % M1,
                                      n_thread_data_on_global / N1,
                                      n_thread_data_on_global % N1))
-            .Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
+            .Run_hack(
+                c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
     }
 }
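Note: the loop above is the usual LDS double-buffering schedule, where the global loads for k-block i+1 are issued into the buffer the GEMM is not currently reading; only the B-copy calls were switched to the `_hack` entry points. A simplified host-side sketch of that schedule (illustration only, not the kernel):

```cpp
// Model of the LDS double-buffering schedule (even/odd ping-pong).
#include <cstdio>

int main()
{
    const int num_k_blocks = 6; // K / KPerBlock, assumed even here
    int buffer = 0;             // 0 = "even" LDS buffer, 1 = "odd" LDS buffer

    std::printf("preload k-block 0 into buffer 0\n");
    for(int k = 0; k + 1 < num_k_blocks; ++k)
    {
        // move the copy window, then load the next k-block into the *other*
        // buffer while the blockwise GEMM consumes the current one
        std::printf("load k-block %d into buffer %d, GEMM on buffer %d\n",
                    k + 1, 1 - buffer, buffer);
        buffer = 1 - buffer;
    }
    std::printf("GEMM on buffer %d (tail)\n", buffer);
    return 0;
}
```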
......
@@ -152,10 +152,15 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
     return __llvm_amdgcn_buffer_load_f32(
         src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false);
 #else
+#if 1 // debug
     float tmp = __llvm_amdgcn_buffer_load_f32(
         src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);

     return src_thread_data_valid ? tmp : float(0);
+#else
+    return __llvm_amdgcn_buffer_load_f32(
+        src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false);
+#endif
 #endif
 }
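Note: the `#if 1 // debug` path masks the loaded value with `src_thread_data_valid`, so lanes whose data is flagged invalid (e.g. they fall into padding) contribute exact zeros; the disabled path returns the raw buffer load. A host-side model of the masked form (illustration only):

```cpp
// Host-side model of the masked load above; invalid lanes must read as 0.0f.
#include <cassert>

float masked_load(const float* p_src, int offset, bool valid)
{
    // a real buffer load never faults; in this host model we simply skip the
    // dereference when the lane is invalid
    return valid ? p_src[offset] : 0.0f;
}

int main()
{
    const float data[4] = {1.f, 2.f, 3.f, 4.f};
    assert(masked_load(data, 2, true) == 3.f);
    assert(masked_load(data, 100, false) == 0.f); // out-of-range lane reads as 0
    return 0;
}
```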
......
@@ -87,7 +87,7 @@
 // thread-invariant, otherwise it's a bug
 // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
 #ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
-#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
+#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 1
 #endif

 // workaround: put all workaround here
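Note: as the macro name suggests, enabling this flag routes the merge transform's constant low-dimension index-diff calculation through an AMD GCN `readfirstlane`; the comment above asserts the value is thread-invariant. A hedged device-side sketch of the idea (HIP/Clang, not the library's code):

```cpp
// When a value is guaranteed to be identical in every lane of a wavefront,
// __builtin_amdgcn_readfirstlane broadcasts lane 0's copy, letting the compiler
// keep it in a scalar register instead of one VGPR per lane.
__device__ int make_wave_uniform(int thread_invariant_value)
{
    return __builtin_amdgcn_readfirstlane(thread_invariant_value);
}
```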
......
@@ -750,6 +750,13 @@ __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce,
     return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
 }

+template <typename Seq, typename Reduce, index_t Init>
+__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
+{
+    return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{})
+        .PushBack(Number<Init>{});
+}
+
 template <typename Seq, typename Reduce, index_t Init>
 __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
 {
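Note: a reverse exclusive scan of the dimension lengths with multiplication and init 1 is exactly how packed strides come out, which is what the packed descriptors below rely on. A plain-C++ worked example, assuming the usual scan semantics (not the compile-time `Sequence` machinery):

```cpp
// lengths <2, 3, 4> scanned from the right, exclusively, with multiplies and
// init 1 give <12, 4, 1>: the packed strides of a 2x3x4 tensor.
#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N>
constexpr std::array<long, N> reverse_exclusive_scan_mul(const std::array<long, N>& lengths)
{
    std::array<long, N> result{};
    long running = 1; // Init
    for(std::size_t i = N; i-- > 0;)
    {
        result[i] = running; // exclusive: current element not yet folded in
        running *= lengths[i];
    }
    return result;
}

int main()
{
    constexpr auto strides = reverse_exclusive_scan_mul<3>({2, 3, 4});
    static_assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1);
    return 0;
}
```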
......
@@ -155,6 +155,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
+    // b threadwise copy 4x1
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
@@ -185,6 +186,40 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+    // cdata = 64, BlockSize = 256, 128x128x8
+    // b threadwise copy 2x2
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 2;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 8;
+    constexpr index_t GemmNLevel1Cluster = 8;
+
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
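Note: the new "b threadwise copy 2x2" configuration above (and its twin in the dynamic device code below) has to satisfy the blockwise-copy constraint that per-thread slice times thread cluster covers the whole block tile, with the cluster fitting in the block. A quick standalone check of the numbers:

```cpp
#include <cassert>

int main()
{
    const int GemmKPerBlock = 8, GemmNPerBlock = 128, BlockSize = 256;

    const int SliceK = 2, SliceN = 2;      // GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
    const int ClusterK = 4, ClusterN = 64; // GemmBBlockCopyThreadClusterLengths_GemmK_GemmN

    assert(SliceK * ClusterK == GemmKPerBlock); // 2 * 4  == 8
    assert(SliceN * ClusterN == GemmNPerBlock); // 2 * 64 == 128
    assert(ClusterK * ClusterN <= BlockSize);   // 256 threads, all participating
    return 0;
}
```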
......
@@ -41,12 +41,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-    const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor<4>(
-        to_multi_index(InDesc::GetLengths()), to_multi_index(InDesc::GetStrides()));
-    const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor<4>(
-        to_multi_index(WeiDesc::GetLengths()), to_multi_index(WeiDesc::GetStrides()));
-    const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor<4>(
-        to_multi_index(OutDesc::GetLengths()), to_multi_index(OutDesc::GetStrides()));
+    // assume packed tensor
+    const auto in_n_c_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed<4>(to_multi_index(InDesc::GetLengths()));
+    const auto wei_k_c_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed<4>(to_multi_index(WeiDesc::GetLengths()));
+    const auto out_n_k_ho_wo_desc =
+        make_dynamic_naive_tensor_descriptor_packed<4>(to_multi_index(OutDesc::GetLengths()));

     const auto conv_strides = to_multi_index(ConvStrides{});
     const auto conv_dilations = to_multi_index(ConvDilations{});
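Note: with packed descriptors the strides are implied by the lengths, so only `GetLengths()` is passed. A standalone sketch of the resulting offset computation for an NCHW tensor (illustration; the library derives this generically from the lengths):

```cpp
// offset(n, c, h, w) = ((n * C + c) * H + h) * W + w for a packed NCHW tensor
#include <cassert>

long packed_offset_nchw(long n, long c, long h, long w, long C, long H, long W)
{
    return ((n * C + c) * H + h) * W + w;
}

int main()
{
    // N = 128, C = 128, Hi = Wi = 28 (the 3x3, 28x28 case enabled in main below)
    assert(packed_offset_nchw(0, 0, 0, 5, 128, 28, 28) == 5);
    assert(packed_offset_nchw(0, 1, 0, 0, 128, 28, 28) == 28 * 28);
    assert(packed_offset_nchw(1, 0, 0, 0, 128, 28, 28) == 128 * 28 * 28);
    return 0;
}
```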
@@ -115,6 +116,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
 #elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
+    // b thread copy 4x1
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
@@ -142,6 +144,37 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
+#elif 1
+    // cdata = 64, BlockSize = 256, 128x128x8
+    // b thread copy 2x2
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 2;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 8;
+    constexpr index_t GemmNLevel1Cluster = 8;
+
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
+    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
+
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
+
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
+
     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
 #endif
@@ -169,7 +202,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
     constexpr auto conv_driver =
 #if 1 // debug
-        DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
+        DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
 #else
         DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
 #endif
......
@@ -217,7 +217,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3, 35x35, stride 2
     constexpr index_t N = 128;
     constexpr index_t C = 288;
@@ -352,7 +352,7 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3, 28x28
     constexpr index_t N = 128;
     constexpr index_t C = 128;
......