Commit 9bfe6591 authored by ltqin's avatar ltqin
Browse files

change name from c_thread_buffer to in_thread_buffer

parent cec9e840
...@@ -22,6 +22,15 @@ struct BlockwiseSoftmax_V1 ...@@ -22,6 +22,15 @@ struct BlockwiseSoftmax_V1
{ {
static_assert(MRepeat == 1, "Now MRepeat must equal 1"); static_assert(MRepeat == 1, "Now MRepeat must equal 1");
// Maps a flat in-wave thread id to its (M, K) cluster coordinate.
// The wave is partitioned as MPerXDL rows (M) by WaveSize/MPerXDL columns (K),
// arranged K-major (arrange order {1, 0}), matching the cluster layout used by
// the blockwise reductions below.
struct BlockToMKMap_M0_K_M1Adapt
{
    __host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;

    // idx_top: flat thread id (presumably the lane id within a wave — TODO confirm
    // against the caller). Returns the multi-index in (M, K) space.
    __host__ __device__ constexpr auto CalculateBottomIndex(const index_t& idx_top) const
    {
        using ThreadClusterLengths_M_K  = Sequence<MPerXDL, WaveSize / MPerXDL>;
        using ThreadClusterArrangeOrder = Sequence<1, 0>;

        // NOTE(review): original declared this `static constexpr`, which is not
        // allowed in a constexpr function before C++23; plain constexpr suffices.
        constexpr auto thread_cluster_desc =
            make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

        // NOTE(review): original computed the descriptor but returned nothing from a
        // non-void function and ignored idx_top; delegate to the descriptor instead.
        return thread_cluster_desc.CalculateBottomIndex(make_multi_index(idx_top));
    }
};
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -46,20 +55,23 @@ struct BlockwiseSoftmax_V1 ...@@ -46,20 +55,23 @@ struct BlockwiseSoftmax_V1
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>; using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using ThreadClusterArrangeOrder = Sequence<1, 0>; using ThreadClusterArrangeOrder = Sequence<1, 0>;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using BlockwiseMaxReduce = using BlockwiseMaxReduce =
PartitionedBlockwiseReduction<AccDataType, PartitionedBlockwiseReduction2<AccDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder, decltype(thread_cluster_desc),
reduce::Max, reduce::Max,
false, // param ignored false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using BlockwiseSumReduce = using BlockwiseSumReduce =
PartitionedBlockwiseReduction<AccDataType, PartitionedBlockwiseReduction2<AccDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder, decltype(thread_cluster_desc),
reduce::Add, reduce::Add,
false, // ignored false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
......
...@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction ...@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction
}; };
}; };
// clang-format off
// Assume:
// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
//
// Variant of PartitionedBlockwiseReduction that takes a pre-built cluster
// descriptor type (ThreadClusterDesc) instead of an arrange-order sequence,
// letting the caller share one descriptor across multiple reductions.
// Performs a tree reduction along K for each of the M partitions; after
// Reduce(), every thread holds the reduced value of its M-row.
template <typename AccDataType,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
typename ThreadClusterDesc,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction2
{
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
// M = number of independent reduction rows; K = elements reduced per row.
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
// Row-major (M, K) view of the workspace buffer.
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc = ThreadClusterDesc{};
// Reduces in_out_value across the K dimension of the thread cluster.
// work_buffer: LDS workspace with >= BlockSize AccDataType elements (see header).
// in_out_value: this thread's input; overwritten with the row-reduced result.
// Must be called by all threads in the block (contains __syncthreads()).
template <typename BufferType>
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
// Number of halving steps; the 1<<shift stepping below implies BufferLength_K
// is expected to be a power of two — get_shift is defined elsewhere, confirm.
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
// Map this thread's flat id to its (m, k) position in the cluster.
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
// Stage every thread's value into LDS, then make it visible block-wide.
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
__syncthreads();
// Binary tree reduction along K: each step folds the upper half of the
// active K range into the lower half.
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset)
{
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = work_buffer[offset1];
AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2);
work_buffer(offset1) = opData1;
}
// Barrier is outside the divergent branch: all threads reach it each step.
__syncthreads();
});
// Result for each M-row lives at k == 0; every thread reads its row's result.
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
in_out_value = work_buffer[offset];
};
};
// clang-format off // clang-format off
// Assume: // Assume:
// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data // 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment