Commit 848ceeb3 authored by ltqin
Browse files

add PartitionedBlockwiseReduction2

parent 9bfe6591
...@@ -22,20 +22,26 @@ struct BlockwiseSoftmax_V1 ...@@ -22,20 +22,26 @@ struct BlockwiseSoftmax_V1
{ {
static_assert(MRepeat == 1, "Now MRepeat must equal 1"); static_assert(MRepeat == 1, "Now MRepeat must equal 1");
struct BlockToMKMap_M0_K_M1Adapt{ static constexpr index_t WaveSize = 64;
__host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
__host__ __device__ constexpr auto CalculateBottomIndex(const index_t& idx_top) const{ struct BlockToMKMap_M0_K_M1Adapt
{
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>; using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using ThreadClusterArrangeOrder = Sequence<1, 0>; using ThreadClusterArrangeOrder = Sequence<1, 0>;
static constexpr auto thread_cluster_desc = __host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
return thread_cluster_desc.CalculateBottomIndex(idx_top);
} }
} };
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr index_t MThreadSliceSize = 1; static constexpr index_t MThreadSliceSize = 1;
static constexpr index_t WaveSize = 64;
constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed( constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
...@@ -54,15 +60,12 @@ struct BlockwiseSoftmax_V1 ...@@ -54,15 +60,12 @@ struct BlockwiseSoftmax_V1
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>; using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using ThreadClusterArrangeOrder = Sequence<1, 0>;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using BlockwiseMaxReduce = using BlockwiseMaxReduce =
PartitionedBlockwiseReduction2<AccDataType, PartitionedBlockwiseReduction2<AccDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, ThreadClusterLengths_M_K,
decltype(thread_cluster_desc), BlockToMKMap_M0_K_M1Adapt,
reduce::Max, reduce::Max,
false, // param ignored false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
...@@ -71,7 +74,7 @@ struct BlockwiseSoftmax_V1 ...@@ -71,7 +74,7 @@ struct BlockwiseSoftmax_V1
PartitionedBlockwiseReduction2<AccDataType, PartitionedBlockwiseReduction2<AccDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, ThreadClusterLengths_M_K,
decltype(thread_cluster_desc), BlockToMKMap_M0_K_M1Adapt,
reduce::Add, reduce::Add,
false, // ignored false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment