"docs/en_US/Overview.md" did not exist on "183763effecb80b47ccfe6963424e7fd269b94b7"
layernorm.cpp 2.34 KB
Newer Older
kahmed10's avatar
kahmed10 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <migraphx/gpu/device/layernorm.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/fast_div.hpp>
#include <stdexcept>
#include <string>
#include <type_traits>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Row-wise layer normalization without learned scale/shift. For every row of
// `arg1` taken along its last dimension:
//
//   m   = x - mean(x)
//   out = m / sqrt(mean(m ^ 2) + 1e-12)
//
// Launch layout: nelements * block_size threads are launched on `stream`.
// Each group of block_size threads handles one row. Each thread caches the row
// elements it visits (j = idx.local, idx.local + block_size, ...) in a
// fixed-size per-thread array. Therefore a row may hold at most
// block_size * max_items_per_thread elements.
//
// stream : HIP stream the kernel is enqueued on.
// result : output buffer. NOTE(review): assumed to have the same shape and
//          element type as arg1; this is not checked here.
// arg1   : input tensor.
//
// Empty tensors are a no-op. Throws std::runtime_error when the last
// dimension is too long for the per-thread cache.
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
{
    // Upper bound on row elements each thread keeps in its private array.
    constexpr std::size_t max_items_per_thread = 4;
    constexpr std::size_t max_block_size       = 256;
    constexpr double epsilon                   = 1e-12;

    const std::size_t relements = arg1.get_shape().lens().back();
    // Nothing to do for empty tensors. This also avoids dividing by a
    // zero-length last dimension and launching an empty grid.
    if(relements == 0 || result.get_shape().elements() == 0)
        return;

    const std::size_t nelements      = result.get_shape().elements() / relements;
    const std::size_t block_size     = compute_block_size(relements, max_block_size);
    const std::size_t block_size_div = encode_divisor(block_size);

    // Each thread caches ceil(relements / block_size) elements in an array of
    // max_items_per_thread slots. A longer row would write past its end.
    // A host assert would vanish in release builds, so reject such input here.
    const std::size_t max_relements = block_size * max_items_per_thread;
    if(relements > max_relements)
        throw std::runtime_error("layernorm: last dimension of size " +
                                 std::to_string(relements) + " exceeds the kernel limit of " +
                                 std::to_string(max_relements));

    hip_visit_all(result, arg1)([&](auto output, auto input) {
        using value_type = typename decltype(input)::value_type;
        // Work in at least single precision. This keeps the sum of squares
        // over up to max_relements values from being accumulated in half
        // precision, where it can overflow. NOTE(review): for integral element
        // types this also makes the arithmetic floating point; layernorm on
        // integer tensors is presumably unused - confirm.
        using compute_type = std::common_type_t<value_type, float>;

        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
            const auto out_idx  = i / block_size;
            const auto base_idx = out_idx * relements;

            // Thread-private cache of the row elements this thread visits.
            compute_type x_data[max_items_per_thread];
            // Map a row position j produced by idx.local_stride to its cache
            // slot, (j - idx.local) / block_size, via the precomputed divisor.
            auto x = [&](auto j) -> compute_type& {
                return x_data[fast_div(j - idx.local, block_size_div)];
            };

            idx.local_stride(relements,
                             [&](auto j) __device__ { x(j) = input.data()[base_idx + j]; });

            // Row mean, then center the cached values in place.
            const compute_type mean =
                block_reduce<max_block_size>(
                    idx, sum{}, 0, relements, [&](auto j) __device__ { return x(j); }) /
                relements;

            idx.local_stride(relements, [&](auto j) __device__ { x(j) = x(j) - mean; });

            // Mean of squares of the centered row (biased variance).
            const compute_type variance =
                block_reduce<max_block_size>(
                    idx, sum{}, 0, relements, [&](auto j) __device__ { return x(j) * x(j); }) /
                relements;

            // Loop-invariant: compute the normalization factor once per thread.
            const compute_type scale = ::rsqrt(variance + epsilon);

            idx.local_stride(relements, [&](auto j) __device__ {
                output.data()[base_idx + j] = x(j) * scale;
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx