#include <migraphx/gpu/device/layernorm.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/fast_div.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

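// Trait to extract the element type from a vec: vector_type_t<vec<T, N>> is T.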
template <class T>
struct vector_type
{
};

template <class T, index_int N>
struct vector_type<vec<T, N>>
{
    using type = T;
};

template <class T>
using vector_type_t = typename vector_type<T>::type;

// Layernorm without affine (scale/shift) parameters:
//   m = x - mean(x)
//   y = m / sqrt(mean(m ^ 2) + 1e-12)

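// Vectorized kernel: each thread reads one vec<T, N>, so a row of relements
// scalars is reduced as relements / N vector elements.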
template <index_int N>
void layernorm_vec_impl(hipStream_t stream,
                        const argument& result,
                        const argument& arg1,
                        index_int nelements,
                        index_int relements)
{
    hip_vec_visit_all<N>(result, arg1)([&](auto output, auto input) {
        using value_type = typename decltype(input)::value_type;

        const auto relements_v           = relements / N;
        const std::size_t max_block_size = 256;
        const std::size_t block_size     = compute_block_size(relements_v, max_block_size);
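        // encode_divisor precomputes a magic-number form of block_size so that
        // fast_div below can divide without a hardware integer division.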
        const std::size_t block_size_div = encode_divisor(block_size);
        assert(relements_v <= block_size);

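        // Launch one block of block_size threads per row of relements scalars.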
        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
            const auto out_idx   = fast_div(i, block_size_div);
            const auto base_idx  = out_idx * relements_v;
            const auto input_idx = base_idx + idx.local;
            const bool in_range  = idx.local < relements_v;

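            // Row mean: block_reduce yields a per-lane partial-sum vec; its N
            // lanes are then summed into a scalar before dividing.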
            auto mean = [&](auto z) {
                auto psum = block_reduce<max_block_size>(
                    idx, sum{}, value_type(0), relements_v, [=](auto) { return z; });
                vector_type_t<value_type> total = 0;
                for(index_int k = 0; k < N; k++)
                    total += psum[k];
                return total / relements;
            };

            // m = x - mean(x)
            value_type x = in_range ? input.data()[input_idx] : 0;
            value_type m = x - mean(x);

            // mean(m ^ 2) + 1e-12
            value_type r = mean(m * m) + value_type(1e-12);

            // rsqrt(mean(m ^ 2) + 1e-12)
            value_type d = 0;
            for(index_int k = 0; k < N; k++)
                d[k] = ::rsqrt(r[k]);
            // m * rsqrt(mean(m ^ 2) + 1e-12)
            if(in_range)
                output.data()[input_idx] = m * d;
        });
    });
}

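// Scalar fallback: one thread per scalar element of the row, used when the row
// length is not a multiple of the vector width.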
void layernorm_impl(hipStream_t stream,
                    const argument& result,
                    const argument& arg1,
                    index_int nelements,
                    index_int relements)
{
    hip_visit_all(result, arg1)([&](auto output, auto input) {
        using value_type = typename decltype(input)::value_type;

        const std::size_t max_block_size = 256;
        const std::size_t block_size     = compute_block_size(relements, max_block_size);
        const std::size_t block_size_div = encode_divisor(block_size);
        assert(relements <= block_size);

        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
            const auto out_idx   = fast_div(i, block_size_div);
            const auto base_idx  = out_idx * relements;
            const auto input_idx = base_idx + idx.local;
            const bool in_range  = idx.local < relements;

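            // in_range zeroes contributions from tail threads when block_size
            // exceeds relements.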
            auto mean = [&](auto z) {
                return block_reduce<max_block_size>(idx,
                                                    sum{},
                                                    value_type(0),
                                                    relements,
                                                    [=](auto) { return in_range ? z : 0; }) /
                       relements;
            };

            // m = x - mean(x)
            value_type x = in_range ? input.data()[input_idx] : 0;
            value_type m = x - mean(x);

            // mean(m ^ 2) + 1e-12
            value_type r = mean(m * m) + value_type(1e-12);

            // m * rsqrt(mean(m ^ 2) + 1e-12)
            if(in_range)
                output.data()[input_idx] = m * ::rsqrt(r);
        });
    });
}

void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
{
    auto relements    = arg1.get_shape().lens().back();
    auto nelements    = result.get_shape().elements() / relements;

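    // Block sizes are capped at max_block_size (256) threads, so the vectorized
    // kernel handles rows of up to 4 * 256 scalars and the scalar kernel rows
    // of up to 256, matching the asserts in the kernels above.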
    if((relements % 4) == 0 and relements <= 1024)
        layernorm_vec_impl<4>(stream, result, arg1, nelements, relements);
    else if(relements <= 256)
        layernorm_impl(stream, result, arg1, nelements, relements);
    else
        MIGRAPHX_THROW("No kernel for layernorm");
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx