#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

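// Reduction functor over packed __half2 values: adds both 16-bit lanes at once via
// the __hadd2 intrinsic. Used as the combining op for the shared-memory reductions below.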
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return __hadd2(x, y); }
};

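// Elementwise max of two packed __half2 values: unpack each pair to float2, compare
// lane by lane, and repack, avoiding any half-precision max intrinsic.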
inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    auto fx2 = __half22float2(x);
    auto fy2 = __half22float2(y);
    auto fx  = fx2.x > fy2.x ? fx2.x : fy2.x;
    auto fy  = fx2.y > fy2.y ? fx2.y : fy2.y;
    return __floats2half2_rn(fx, fy);
}

struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return hmax2(x, y); }
};

// buffer resides in shared memory
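// Tree reduction across the block: each step folds the upper half of the active range
// into the lower half, synchronizing between steps. The surviving __half2 in buffer[0]
// still holds two partial results, so its low and high lanes are combined once more;
// both lanes of the returned __half2 carry the final value.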
template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    __syncthreads();
    for(index_int s = block_size; s > 0; s >>= 1)
    {
        if(tid < s and tid + s < batch_item_num)
        {
            buffer[tid] = op(buffer[tid], buffer[tid + s]);
        }
        __syncthreads();
    }

    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);

    return op(lows2, highs2);
}

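// Softmax over the last axis for half data, one block per softmax slice. The buffers
// are reinterpreted as __half2 so each thread handles two values at a time. Dynamic
// shared memory is split into a reduction scratch area and a cached copy of the slice;
// the kernel then performs the usual three passes: row max, exp and sum, divide.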
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half2* input  = reinterpret_cast<__half2*>(data_in);
    __half2* output = reinterpret_cast<__half2*>(data_out);
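    // Each __half2 packs two elements, so index in units of pairs
    // (this assumes batch_item_num is even).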
    batch_item_num /= 2;
    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];

    __half2* in_data_reduce = buffer2;
    __half2* in_data        = buffer2 + batch_item_num;
    int start               = blockIdx.x * batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    auto batch_max =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = h2exp(__hsub2(in_data[i], batch_max));
        in_data_reduce[i] = in_data[i];
    }

    auto batch_sum =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __h2div(in_data[i], batch_sum);
    }
}

// data resides in shared memory
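// Scalar __half variant of the tree reduction: same folding scheme, but starting at
// block_size / 2 and returning data[0] directly, since no lane unpacking is needed.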
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    __syncthreads();
    for(index_int s = block_size / 2; s > 0; s >>= 1)
    {
        if(tid < s and tid + s < batch_item_num)
        {
            data[tid] = op(data[tid], data[tid + s]);
        }
        __syncthreads();
    }

    return data[0];
}

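// Scalar variant of the shared-memory softmax kernel. It follows the same three passes
// as softmax_kernel but works on individual __half values and converts to float for the
// exponential and the final division.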
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half* input  = reinterpret_cast<__half*>(data_in);
    __half* output = reinterpret_cast<__half*>(data_out);
    extern MIGRAPHX_DEVICE_SHARED __half buffer[];

    __half* in_data_reduce = buffer;
    __half* in_data        = buffer + batch_item_num;
    int start              = blockIdx.x * batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
        in_data_reduce[i] = in_data[i];
    }

    auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
    }
}

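// Host entry point. For half data reduced over the last axis with at most 1024 elements,
// dispatch the specialized shared-memory kernel; otherwise fall back to the generic
// block_reduce paths, indexing by flat offset for the last axis and by multi-dimensional
// index for any other axis.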
void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    auto batch_lens          = result.get_shape().lens();
    index_int batch_item_num = batch_lens[axis];
    batch_lens[axis]         = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
        const index_int max_block_size = 128;
        const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        type init  = lowest();

        if(axis == batch_lens.size() - 1)
        {
            auto in_type = result.get_shape().type();
            if(in_type == shape::half_type and batch_item_num <= 1024)
            {
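                // Two shared-memory buffers of batch_item_num halves each, hence
                // batch_item_num * 2 * sizeof(half) bytes in total. The launch uses a
                // quarter of the computed block size; each thread strides over the
                // slice in __half2 pairs.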
                auto half2_block_size = compute_block_size(batch_item_num, 1024);
                int block_num         = batch_shape.elements();
                int shared_size       = batch_item_num * 2 * result.get_shape().type_size();
                half2_block_size      = half2_block_size / 4;
                softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
                    arg.data(), batch_item_num, half2_block_size, result.data());
            }
            else
            {
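                // Generic last-axis path: each group of block_size work-items handles
                // one slice, using block_reduce for the max and for the sum of
                // exponentials before writing the normalized values.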
                gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        auto start_loc = i / block_size * batch_item_num;
                        auto batch_max = block_reduce<max_block_size>(
                            idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                                return input[start_loc + j];
                            });

                        auto batch_sum = block_reduce<max_block_size>(
                            idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                                auto val = input[start_loc + j] - batch_max;
                                return ::exp(to_hip_type(val));
                            });

                        idx.local_stride(batch_item_num, [&](auto j) __device__ {
                            auto val              = input[start_loc + j] - batch_max;
                            output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
                        });
                    });
            }
        }
        else
        {
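            // Reduction over a non-last axis: recover the multi-dimensional index of
            // the slice and step along `axis` explicitly for each element.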
            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                [=](auto i, auto idx) __device__ {
                    auto data_idx  = batch.multi(i / block_size);
                    auto batch_max = block_reduce<max_block_size>(
                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            return input[data_idx];
                        });

                    auto batch_sum = block_reduce<max_block_size>(
                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            auto val       = input[data_idx] - batch_max;
                            return ::exp(to_hip_type(val));
                        });

                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
                        data_idx[axis]   = j;
                        auto val         = input[data_idx] - batch_max;
                        output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
                    });
                });
        }
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx