prefix_scan_sum.cpp 4.87 KB
Newer Older
1
2
3
#include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>
4
#include <migraphx/gpu/device/reduce.hpp>
5
6
7
8
9
10
11
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

12
13
14
15
16
17
void prefix_scan_sum(hipStream_t stream,
                     const argument& result,
                     const argument& arg,
                     int32_t axis,
                     bool exclusive,
                     bool reverse)
18
{
19
20
21
22
23
    const index_int max_block_size = 256;
    const index_int n              = arg.get_shape().lens()[axis];
    auto rlens                     = result.get_shape().lens();
    rlens[axis]                    = 1;

24
25
    hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
        [=](auto output, auto input, auto rshape) {
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
            const index_int block_size = compute_block_size(rshape.elements(), max_block_size);
            if(reverse and exclusive)
            {
                gs_launch(stream, rshape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        const auto ridx  = rshape.multi(i / block_size);
                        auto compute_idx = [&](auto j) {
                            auto k  = ridx;
                            k[axis] = j;
                            return k;
                        };
                        block_scan<max_block_size>(
                            idx,
                            sum{},
                            0,
                            n,
                            reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
                            reverse_scan(n, [&](auto j, auto x) {
                                if(j == n - 1)
                                    output[compute_idx(j)] = 0;
                                if(j > 0)
                                    output[compute_idx(j - 1)] = x;
                            }));
                    });
            }
            else if(reverse)
            {
                gs_launch(stream, rshape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        const auto ridx  = rshape.multi(i / block_size);
                        auto compute_idx = [&](auto j) {
                            auto k  = ridx;
                            k[axis] = j;
                            return k;
                        };
                        block_scan<max_block_size>(
                            idx,
                            sum{},
                            0,
                            n,
                            reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
                            reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; }));
                    });
            }
            else if(exclusive)
            {
                gs_launch(stream, rshape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        const auto ridx  = rshape.multi(i / block_size);
                        auto compute_idx = [&](auto j) {
                            auto k  = ridx;
                            k[axis] = j;
                            return k;
                        };
                        block_scan<max_block_size>(
                            idx,
                            sum{},
                            0,
                            n,
                            [&](auto j) { return input[compute_idx(j)]; },
                            [&](auto j, auto x) {
                                auto k = j + 1;
                                if(j == 0)
                                    output[compute_idx(0)] = 0;
                                if(k < n)
                                    output[compute_idx(k)] = x;
                            });
                    });
            }
            else
            {
                gs_launch(stream, rshape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        const auto ridx  = rshape.multi(i / block_size);
                        auto compute_idx = [&](auto j) {
                            auto k  = ridx;
                            k[axis] = j;
                            return k;
                        };
                        block_scan<max_block_size>(
                            idx,
                            sum{},
                            0,
                            n,
                            [&](auto j) { return input[compute_idx(j)]; },
                            [&](auto j, auto x) { output[compute_idx(j)] = x; });
                    });
            }
114
115
116
117
118
119
120
        });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx