Commit bd22abb5 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent e131f6aa
......@@ -52,12 +52,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
// map threads to cluster
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
const auto thread_cluster_id =
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
......@@ -67,6 +69,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
}
__device__ static constexpr index_t GetThreadBufferSize()
{
......@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
if(BlockSize == thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
}
else
{
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
}
}
else if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
......@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
if(BlockSize == thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
p_block_dst);
}
else
{
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
}
}
else if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
......@@ -158,27 +132,39 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
// if there is type conversion, it's done during store
RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
}
}
template <typename T, bool PositiveDirection>
__device__ void
MoveSrcSliceWindow(const T& step_sizes,
                   integral_constant<bool, PositiveDirection> positive_direction)
{
    // Advance this thread's source slice window by step_sizes. Only threads
    // that were mapped onto the thread cluster take part; when BlockSize
    // equals the cluster element count, every thread in the block is mapped.
    const bool thread_is_mapped =
        (BlockSize == mThreadClusterDesc.GetElementSize()) ||
        (get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize());

    if(thread_is_mapped)
    {
        mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
    }
}
template <typename T, bool PositiveDirection>
__device__ void
MoveDstSliceWindow(const T& step_sizes,
                   integral_constant<bool, PositiveDirection> positive_direction)
{
    // Advance this thread's destination slice window by step_sizes. Only
    // threads mapped onto the thread cluster take part; when BlockSize equals
    // the cluster element count, every thread in the block is mapped.
    const bool thread_is_mapped =
        (BlockSize == mThreadClusterDesc.GetElementSize()) ||
        (get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize());

    if(thread_is_mapped)
    {
        mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
    }
}
private:
using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
......@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
DstAddressSpace,
DstInMemOp>;
static constexpr auto mThreadClusterDesc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment