"src/include/Sequence.hpp" did not exist on "766b0a9eafe29a5d2a75c350345e54165ceaf405"
Commit bd22abb5 authored by Chao Liu

refactor

parent e131f6aa
@@ -52,12 +52,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
         is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
         "wrong! threads should be mapped to cover entire slicing window");

-        // map threads to cluster
-        constexpr auto thread_cluster_desc =
-            make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
+        static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
+                      "wrong! BlockSize too small");

+        if(BlockSize == mThreadClusterDesc.GetElementSize() or
+           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
+        {
             const auto thread_cluster_id =
-                thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
+                mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());

             const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
@@ -67,6 +69,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
             mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
             mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
         }
+    }

    __device__ static constexpr index_t GetThreadBufferSize()
    {
@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
        constexpr bool has_optimized_address_calculation =
            decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();

-        constexpr auto thread_cluster_desc =
-            make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
-        if(BlockSize == thread_cluster_desc.GetElementSize())
-        {
-            // TODO: threadwise copy is still being tweaked
-            if(has_optimized_address_calculation)
-            {
-                mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
-            }
-            else
-            {
-                mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
-            }
-        }
-        else if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
+        if(BlockSize == mThreadClusterDesc.GetElementSize() or
+           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            // TODO: threadwise copy is still being tweaked
            if(has_optimized_address_calculation)
@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
        constexpr bool has_optimized_address_calculation =
            decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();

-        constexpr auto thread_cluster_desc =
-            make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
-
-        if(BlockSize == thread_cluster_desc.GetElementSize())
-        {
-            // TODO: threadwise copy is still being tweaked
-            if(has_optimized_address_calculation)
-            {
-                mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
-                                                                       p_block_dst);
-            }
-            else
-            {
-                mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
-            }
-        }
-        else if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
+        if(BlockSize == mThreadClusterDesc.GetElementSize() or
+           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            // TODO: threadwise copy is still being tweaked
            if(has_optimized_address_calculation)
@@ -158,27 +132,39 @@ struct BlockwiseGenericTensorSliceCopy_v4
        BlockSrcData p_thread_buffer[GetThreadBufferSize()];

+        if(BlockSize == mThreadClusterDesc.GetElementSize() or
+           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
+        {
            RunLoadThreadBuffer(p_block_src, p_thread_buffer);

            // if there is type conversion, it's done during store
            RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
        }
+    }

    template <typename T, bool PositiveDirection>
    __device__ void
    MoveSrcSliceWindow(const T& step_sizes,
                       integral_constant<bool, PositiveDirection> positive_direction)
+    {
+        if(BlockSize == mThreadClusterDesc.GetElementSize() or
+           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
        }
+    }

    template <typename T, bool PositiveDirection>
    __device__ void
    MoveDstSliceWindow(const T& step_sizes,
                       integral_constant<bool, PositiveDirection> positive_direction)
+    {
+        if(BlockSize == mThreadClusterDesc.GetElementSize() or
+           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
        {
            mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
        }
+    }

    private:
    using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
                            DstAddressSpace,
                            DstInMemOp>;

+    static constexpr auto mThreadClusterDesc =
+        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
+
    ThreadwiseLoad mThreadwiseLoad;
    ThreadwiseStore mThreadwiseStore;
};
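Note (not from this commit): the refactor hoists the thread-cluster descriptor into the static constexpr member mThreadClusterDesc and wraps every load, store, and slice-window move in the guard `BlockSize == mThreadClusterDesc.GetElementSize() or get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()`, so that when the block launches more threads than the cluster needs, the surplus threads simply idle. Below is a minimal stand-alone sketch of that guard pattern; the names (BlockwiseCopySketch, ClusterSize, thread_id) are illustrative assumptions and do not appear in this repository.

// sketch.cpp -- illustrative only; compiles with any C++14 compiler (g++ -std=c++14 sketch.cpp)
#include <cstdio>

template <int BlockSize, int ClusterSize>
struct BlockwiseCopySketch
{
    // Mirrors the new static_assert: the block must have at least as many
    // threads as the cluster that actually carries data.
    static_assert(BlockSize >= ClusterSize, "wrong! BlockSize too small");

    // thread_id stands in for get_thread_local_1d_id() in the real code.
    static void Run(const float* src, float* dst, int thread_id)
    {
        // Guard pattern from the commit: when BlockSize equals the cluster
        // size every thread works; otherwise only ids below ClusterSize do.
        if(BlockSize == ClusterSize || thread_id < ClusterSize)
        {
            dst[thread_id] = src[thread_id];
        }
        // threads with thread_id >= ClusterSize fall through and do nothing
    }
};

int main()
{
    float src[4] = {1.f, 2.f, 3.f, 4.f};
    float dst[4] = {0.f, 0.f, 0.f, 0.f};

    // BlockSize 4 but only a 2-thread cluster: ids 2 and 3 stay idle.
    for(int tid = 0; tid < 4; ++tid)
        BlockwiseCopySketch<4, 2>::Run(src, dst, tid);

    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // prints: 1 2 0 0
    return 0;
}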