softmax.cpp 7.92 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
5
#include <migraphx/gpu/device/reduce.hpp>
Khalique's avatar
Khalique committed
6
7
8
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
9
10
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
Khalique's avatar
Khalique committed
11
12
13
14
15
16

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

17
18
// Binary functor performing lane-wise fp16 addition of two __half2 values;
// used as the reduction op when block-summing packed half pairs.
struct half2_sum
{
    // Delegates to the __hadd2 intrinsic (adds both lanes in one op).
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 a, __half2 b) const
    {
        return __hadd2(a, b);
    }
};

// Lane-wise maximum of two __half2 values. The comparison is done in float
// (via __half22float2) rather than with fp16 compare intrinsics; the result
// is rounded back to half with round-to-nearest.
inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    float2 fa = __half22float2(x);
    float2 fb = __half22float2(y);
    // Plain `>` ternary (not fmaxf) keeps the original NaN-propagation
    // behavior: if fa is NaN the fb lane is selected.
    float lo = (fa.x > fb.x) ? fa.x : fb.x;
    float hi = (fa.y > fb.y) ? fa.y : fb.y;
    return __floats2half2_rn(lo, hi);
}

// Binary functor taking the lane-wise maximum of two __half2 values; the
// reduction-op counterpart to half2_sum.
struct half2_max
{
    // Delegates to hmax2 (float-domain compare, rounded back to fp16).
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 a, __half2 b) const
    {
        return hmax2(a, b);
    }
};

// Tree-reduces `batch_item_num` __half2 values held in shared memory
// (`buffer`), then folds the low and high lanes of the result together, so
// every thread returns a __half2 whose two lanes both hold the full scalar
// reduction. `buffer` is expected to have been filled cooperatively by all
// threads of the block; `tid` is threadIdx.x and `block_size` is blockDim.x
// (assumed — confirm at the call sites).
// NOTE(review): with index = 2*s*tid the tree only pairs up the first
// 2*block_size entries at s == 1, so correctness appears to require
// batch_item_num <= 2 * block_size — verify compute_block_size guarantees
// this for every size the callers allow.
template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Barrier so the caller's shared-memory writes are visible before the
    // cross-thread reads below (the callers do not synchronize themselves);
    // without this the first reduction step races with the fill loop.
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            buffer[index] = op(buffer[index], buffer[index + s]);
        }
        __syncthreads();
    }

    // Read the final partial into registers before releasing the buffer.
    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);
    // Let every thread finish reading buffer[0] before the caller overwrites
    // the shared buffer in its next phase.
    __syncthreads();
    return op(lows2, highs2);
}

Shucai Xiao's avatar
Shucai Xiao committed
57
58
// Softmax over the last (contiguous) axis for fp16 data, processed two
// elements at a time as __half2. One block handles one batch row.
// Dynamic shared memory: 2 * batch_item_num halves — two __half2 buffers of
// batch_item_num/2 entries each (one scratch copy consumed by block_reduce,
// one copy that survives the reductions).
// Assumes batch_item_num is even and blockDim.x == block_size — TODO confirm
// at the launch site.
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half2* input  = reinterpret_cast<__half2*>(data_in);
    __half2* output = reinterpret_cast<__half2*>(data_out);
    // Work in units of __half2 pairs from here on (truncates if odd).
    batch_item_num /= 2;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];

    // Scratch copy destroyed by the reductions.
    __half2* in_data_reduce = buffer2;
    // Second copy, kept intact for the exp/divide phases.
    __half2* in_data        = buffer2 + batch_item_num;
    // First __half2 index of this block's batch row (tid / block_size is the
    // batch index, assuming blockDim.x == block_size).
    int start               = tid / block_size * batch_item_num;
    // Stage this row into shared memory, block-strided across threads.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    // NOTE(review): no __syncthreads() between the fill loop above and the
    // cross-thread reads inside block_reduce — verify block_reduce
    // synchronizes on entry, otherwise this is a shared-memory race.
    // Row maximum (both lanes of batch_max hold the same value).
    auto batch_max =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});

    // exp(x - max) for numerical stability; refill the scratch buffer for
    // the sum reduction.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = h2exp(__hsub2(in_data[i], batch_max));
        in_data_reduce[i] = in_data[i];
    }

    // Row sum of the exponentials (both lanes hold the same value).
    auto batch_sum =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});

    // Normalize and write back.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __h2div(in_data[i], batch_sum);
    }
}

// Tree-reduces `batch_item_num` scalar __half values held in shared memory
// (`data`); every thread returns the reduced value left in data[0]. `tid` is
// threadIdx.x and `block_size` is blockDim.x (assumed — confirm at the call
// sites).
// NOTE(review): with index = 2*s*tid the tree only pairs up the first
// 2*block_size entries at s == 1, so correctness appears to require
// batch_item_num <= 2 * block_size — verify the callers guarantee this.
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Barrier so the caller's shared-memory writes are visible before the
    // cross-thread reads below (the callers do not synchronize themselves).
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            data[index] = op(data[index], data[index + s]);
        }
        __syncthreads();
    }

    // Read the result into a register, then let every thread finish reading
    // data[0] before the caller overwrites the shared buffer.
    auto result = data[0];
    __syncthreads();
    return result;
}

Shucai Xiao's avatar
Shucai Xiao committed
112
113
// Scalar-fp16 variant of softmax_kernel: softmax over the last axis, one
// block per batch row, one __half element per iteration. Math is done in
// float (exp and divide) and rounded back to half.
// Dynamic shared memory: 2 * batch_item_num halves (scratch copy for the
// reductions plus a surviving copy).
// NOTE(review): not referenced by the host code visible in this file —
// presumably kept for odd batch_item_num; confirm before removing.
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half* input  = reinterpret_cast<__half*>(data_in);
    __half* output = reinterpret_cast<__half*>(data_out);
    int tid        = blockDim.x * blockIdx.x + threadIdx.x;
    extern MIGRAPHX_DEVICE_SHARED __half buffer[];

    // Scratch copy destroyed by the reductions.
    __half* in_data_reduce = buffer;
    // Second copy, kept intact for the exp/divide phases.
    __half* in_data        = buffer + batch_item_num;
    // First element of this block's batch row (tid / block_size is the batch
    // index, assuming blockDim.x == block_size).
    int start              = tid / block_size * batch_item_num;
    // Stage this row into shared memory, block-strided across threads.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    // NOTE(review): no __syncthreads() between the fill loop above and the
    // cross-thread reads inside block_reduce2 — verify block_reduce2
    // synchronizes on entry, otherwise this is a shared-memory race.
    auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
    // exp(x - max) for numerical stability; refill the scratch buffer for
    // the sum reduction.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {

        in_data[i]        = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
        in_data_reduce[i] = in_data[i];
    }

    auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});

    // Normalize (in float) and write back.
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
    }
}

146
// Computes softmax of `arg` along `axis`, writing into `result`.
// Dispatches between three implementations:
//   1. fp16 last-axis fast path: the hand-written __half2 shared-memory
//      kernel (softmax_kernel), one block per batch row;
//   2. generic last-axis path: gs_launch + block_reduce with flat indexing;
//   3. non-last axis: gs_launch + block_reduce with multi-index addressing.
void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    auto batch_lens          = result.get_shape().lens();
    // Number of elements reduced per softmax row.
    index_int batch_item_num = batch_lens[axis];
    // Collapse the reduced axis to get one shape element per batch row.
    batch_lens[axis]         = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
        const index_int max_block_size = 128;
        const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        // Identity element for the max reduction.
        type init  = lowest();

        if(axis == batch_lens.size() - 1)
        {
            auto in_type = result.get_shape().type();
            // fp16 fast path. softmax_kernel halves batch_item_num and reads
            // the row as __half2 pairs, so it is only valid for an even
            // element count — odd sizes would misindex every row and must
            // fall through to the generic path below.
            // NOTE(review): the shared-memory reduction tree in block_reduce
            // appears to require batch_item_num/2 <= 2 * block_size; with
            // max_block_size == 128 that is 512 halves, yet sizes up to 2048
            // are accepted here — verify compute_block_size covers this.
            if(in_type == shape::half_type and batch_item_num % 2 == 0 and
               batch_item_num <= 2048)
            {
                // One block per batch row.
                int block_num   = batch_shape.elements();
                // Two shared copies of the row: batch_item_num * 2 halves.
                int shared_size = batch_item_num * 2 * result.get_shape().type_size();
                softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
                    arg.data(), batch_item_num, block_size, result.data());
            }
            else
            {
                // Generic last-axis path: rows are contiguous, so flat
                // offsets (start_loc + j) address each row directly.
                gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        auto start_loc = i / block_size * batch_item_num;
                        // Row maximum, for numerical stability.
                        auto batch_max = block_reduce<max_block_size>(
                            idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                                return input[start_loc + j];
                            });

                        // Sum of exp(x - max) over the row.
                        auto batch_sum = block_reduce<max_block_size>(
                            idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                                auto val = input[start_loc + j] - batch_max;
                                return ::exp(to_hip_type(val));
                            });

                        // Normalize and write back.
                        idx.local_stride(batch_item_num, [&](auto j) __device__ {
                            auto val              = input[start_loc + j] - batch_max;
                            output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
                        });
                    });
            }
        }
        else
        {
            // Non-last axis: address elements through a multi-index with the
            // reduced axis position substituted per step.
            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                [=](auto i, auto idx) __device__ {
                    auto data_idx  = batch.multi(i / block_size);
                    // Row maximum, for numerical stability.
                    auto batch_max = block_reduce<max_block_size>(
                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            return input[data_idx];
                        });

                    // Sum of exp(x - max) over the row.
                    auto batch_sum = block_reduce<max_block_size>(
                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            auto val       = input[data_idx] - batch_max;
                            return ::exp(to_hip_type(val));
                        });

                    // Normalize and write back.
                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
                        data_idx[axis]   = j;
                        auto val         = input[data_idx] - batch_max;
                        output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
                    });
                });
        }
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx