"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "ab6633295c7dd26aa38a80982d03be14c59b2429"
Commit 848ceeb3 authored by ltqin's avatar ltqin
Browse files

add PartitionedBlockwiseReduction2

parent 9bfe6591
......@@ -22,20 +22,26 @@ struct BlockwiseSoftmax_V1
{
static_assert(MRepeat == 1, "Now MRepeat must equal 1");
struct BlockToMKMap_M0_K_M1Adapt{
static constexpr index_t WaveSize = 64;
struct BlockToMKMap_M0_K_M1Adapt
{
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using ThreadClusterArrangeOrder = Sequence<1, 0>;
__host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;
__host__ __device__ constexpr auto CalculateBottomIndex(const index_t& idx_top) const{
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using ThreadClusterArrangeOrder = Sequence<1, 0>;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
return thread_cluster_desc.CalculateBottomIndex(idx_top);
}
}
};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t MThreadSliceSize = 1;
static constexpr index_t WaveSize = 64;
constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
......@@ -53,28 +59,25 @@ struct BlockwiseSoftmax_V1
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using ThreadClusterArrangeOrder = Sequence<1, 0>;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
using BlockwiseMaxReduce =
PartitionedBlockwiseReduction2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
decltype(thread_cluster_desc),
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
BlockSize,
ThreadClusterLengths_M_K,
BlockToMKMap_M0_K_M1Adapt,
reduce::Max,
false, // param ignored
detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;
using BlockwiseSumReduce =
PartitionedBlockwiseReduction2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
decltype(thread_cluster_desc),
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
BlockSize,
ThreadClusterLengths_M_K,
BlockToMKMap_M0_K_M1Adapt,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment