// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/device/ops/common.hpp"
#include "ck_tile/device/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/device/ops/reduce/block/block_reduce2d_default_policy.hpp"

namespace ck_tile {

template <typename BlockWarps, // num warps along seq<M, N>
          typename BlockTile,  // block size, seq<M, N>
          typename WarpTile,   // warp size, seq<M, N>
          typename Vector>     // contiguous elements (vector size) along seq<M, N>
struct Reduce2dShape
{
    static constexpr index_t Block_M = BlockTile::at(number<0>{});
    static constexpr index_t Block_N = BlockTile::at(number<1>{});

    static constexpr index_t Warp_M = WarpTile::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile::at(number<1>{});

    static constexpr index_t Vector_M = Vector::at(number<0>{});
    static constexpr index_t Vector_N = Vector::at(number<1>{});

    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});

    // lanes per warp along each dimension; each lane owns a Vector_M x Vector_N chunk
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    // iterations each warp needs to cover its share of the block tile
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // total threads per block: warpSize times the total number of warps
    static constexpr index_t BlockSize =
        warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
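
// Example instantiation (an illustrative sketch, not a configuration shipped
// with this file): the numbers are chosen so every derived quantity divides
// evenly, assuming warpSize == 64 on AMD hardware.
//
//   using ExampleShape = Reduce2dShape<sequence<4, 1>,    // warps per block (M, N)
//                                      sequence<32, 128>, // block tile (M, N)
//                                      sequence<8, 64>,   // warp tile (M, N)
//                                      sequence<1, 8>>;   // vector size (M, N)
//
//   Derived values: ThreadPerWarp = 8 x 8 = 64 lanes, Repeat_M = 32/(4*8) = 1,
//   Repeat_N = 128/(1*64) = 2, BlockSize = 64 * (4*1) = 256 threads.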

template <typename XDataType_,
          typename ComputeDataType_,
          typename YDataType_,
          typename BlockShape_,
          typename ReduceOp_>
struct Reduce2dProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType       = remove_cvref_t<YDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;
    using ReduceOp        = ReduceOp_;

    // partials must be combined across lanes whenever more than one lane per
    // warp covers the reduced (N) dimension, and across warps whenever more
    // than one warp does
    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
};
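
// Sketch of a reduce op satisfying the interface used below (a binary
// operator() plus GetIdentityValue<T>()), together with a hypothetical
// problem instantiation. SumReduceOp and ExampleProblem are illustrative
// names, not library types.
//
//   struct SumReduceOp
//   {
//       template <typename T>
//       CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() { return T{0}; }
//
//       template <typename T>
//       CK_TILE_HOST_DEVICE T operator()(const T& a, const T& b) const { return a + b; }
//   };
//
//   using ExampleProblem =
//       Reduce2dProblem<fp16_t, float, fp16_t, ExampleShape, SumReduceOp>;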

// Row-wise 2D reduction: each workgroup reduces a Block_M x N slab of x along
// N and writes Block_M results to y.
template <typename Problem_, typename Policy_ = BlockReduce2dDefaultPolicy>
struct Reduce
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;

    // Reference path, disabled: calls block_tile_reduce directly instead of
    // going through the policy objects used in the active branch below.
#if 0
    CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const
    {
        using S = typename Problem::BlockShape;

        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});

        const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
            p_y, make_tuple(M), number<1>{});

        const auto iM = get_block_id() * S::Block_M;

        auto x_window = make_tile_window(x_m_n,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());

        auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});

        const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };

        const XDataType reduce_init_value = 0;

        constexpr auto reduce_dims = sequence<1>{};

        auto y_compute = decltype(block_tile_reduce<ComputeDataType>(
            load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){};

        set_tile(y_compute, reduce_init_value);

        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
            block_tile_reduce(y_compute, x, reduce_dims, f_reduce);
            move_tile_window(x_window, {0, S::Block_N});
        }

        block_tile_reduce_sync(y_compute, f_reduce);

        store_tile(y_window, cast_tile<YDataType>(y_compute));
    }
#else
    // Active path: per-thread accumulation over N tiles, then intra-warp and
    // cross-warp combines provided by the policy.
    CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const
    {
        using S = typename Problem::BlockShape;

        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});

        const auto y_m = make_naive_tensor_view_packed<address_space_enum::global>(
            p_y, make_tuple(M), number<1>{});

        const auto iM = get_block_id() * S::Block_M;

        auto x_window = make_tile_window(x_m_n,
                                         make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
                                         {iM, 0},
                                         Policy::template MakeXBlockTileDistribution<Problem>());

        auto y_window = make_tile_window(y_m, make_tuple(number<S::Block_M>{}), {iM});

        // LDS scratch for the cross-warp reduction stage
        __shared__ char smem[Policy::template GetSmemSize<Problem>()];

        // number of Block_N-wide tiles to sweep along N; readfirstlane keeps
        // the trip count in a scalar register
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));

        auto reduce_func         = typename Problem::ReduceOp{};
        auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        // per-thread accumulator, initialized with the op's identity value
        using XTensorType = decltype(load_tile(x_window));
        auto y_compute    = block_reduce2d.template MakeYBlockTile<XTensorType>();
        set_tile(y_compute, reduce_func.template GetIdentityValue<ComputeDataType>());

        // sweep the row block across N, accumulating per-thread partials
        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x = load_tile(x_window);
            block_reduce2d(x, y_compute, reduce_func);
            move_tile_window(x_window, {0, S::Block_N});
        }

        // combine partials: first across lanes within each warp, then across
        // warps through the LDS scratch
        block_reduce2d_sync(y_compute, reduce_func);
        block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func);

        store_tile(y_window, cast_tile<YDataType>(y_compute));
    }
#endif
};

} // namespace ck_tile
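
// Usage sketch (illustrative; the kernel wrapper and launch code below are
// assumptions for a HIP translation unit, not part of this header):
//
//   template <typename Kernel, typename X, typename Y>
//   __global__ void reduce_entry(const X* p_x, Y* p_y,
//                                ck_tile::index_t M, ck_tile::index_t N)
//   {
//       Kernel{}(p_x, p_y, M, N); // each workgroup reduces Block_M rows
//   }
//
//   using Kernel = ck_tile::Reduce<ExampleProblem>;
//   dim3 grid(ck_tile::integer_divide_ceil(M, ExampleShape::Block_M));
//   dim3 block(ExampleShape::BlockSize);
//   reduce_entry<Kernel><<<grid, block>>>(p_x_dev, p_y_dev, M, N);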