Commit bd22abb5 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent e131f6aa
...@@ -52,20 +52,23 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -52,20 +52,23 @@ struct BlockwiseGenericTensorSliceCopy_v4
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{}, is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window"); "wrong! threads should be mapped to cover entire slicing window");
// map threads to cluster static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
constexpr auto thread_cluster_desc = "wrong! BlockSize too small");
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
const auto thread_cluster_id = if(BlockSize == mThreadClusterDesc.GetElementSize() or
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id()); get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
const auto thread_cluster_id =
mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin); mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>()); mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>()); mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin); mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
} }
__device__ static constexpr index_t GetThreadBufferSize() __device__ static constexpr index_t GetThreadBufferSize()
...@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -80,22 +83,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc = if(BlockSize == mThreadClusterDesc.GetElementSize() or
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
if(BlockSize == thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
}
else
{
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
}
}
else if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
{ {
// TODO: threadwise copy is still being tweaked // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation) if(has_optimized_address_calculation)
...@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -116,23 +105,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc = if(BlockSize == mThreadClusterDesc.GetElementSize() or
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
if(BlockSize == thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
p_block_dst);
}
else
{
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
}
}
else if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
{ {
// TODO: threadwise copy is still being tweaked // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation) if(has_optimized_address_calculation)
...@@ -158,10 +132,14 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -158,10 +132,14 @@ struct BlockwiseGenericTensorSliceCopy_v4
BlockSrcData p_thread_buffer[GetThreadBufferSize()]; BlockSrcData p_thread_buffer[GetThreadBufferSize()];
RunLoadThreadBuffer(p_block_src, p_thread_buffer); if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
RunLoadThreadBuffer(p_block_src, p_thread_buffer);
// if there is type conversion, it's done during store // if there is type conversion, it's done during store
RunStoreThreadBuffer(p_thread_buffer, p_block_dst); RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
}
} }
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
...@@ -169,7 +147,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -169,7 +147,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
MoveSrcSliceWindow(const T& step_sizes, MoveSrcSliceWindow(const T& step_sizes,
integral_constant<bool, PositiveDirection> positive_direction) integral_constant<bool, PositiveDirection> positive_direction)
{ {
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction); if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
}
} }
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
...@@ -177,7 +159,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -177,7 +159,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
MoveDstSliceWindow(const T& step_sizes, MoveDstSliceWindow(const T& step_sizes,
integral_constant<bool, PositiveDirection> positive_direction) integral_constant<bool, PositiveDirection> positive_direction)
{ {
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction); if(BlockSize == mThreadClusterDesc.GetElementSize() or
get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
{
mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
}
} }
private: private:
...@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -205,6 +191,9 @@ struct BlockwiseGenericTensorSliceCopy_v4
DstAddressSpace, DstAddressSpace,
DstInMemOp>; DstInMemOp>;
static constexpr auto mThreadClusterDesc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
ThreadwiseLoad mThreadwiseLoad; ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore; ThreadwiseStore mThreadwiseStore;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment