// softmax.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

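// Binary functors used by the shared-memory block reductions below. The
// __half2 variants operate on two fp16 lanes at a time, so the fp16 softmax
// path processes a pair of elements per instruction.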
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return __hadd2(x, y); }
};

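// Lane-wise maximum of two __half2 values. The comparison is done in float
// (via __half22float2) rather than with fp16 intrinsics, then the result is
// packed back to __half2 with round-to-nearest.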
inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    auto fx2 = __half22float2(x);
    auto fy2 = __half22float2(y);
    auto fx  = fx2.x > fy2.x ? fx2.x : fy2.x;
    auto fy  = fx2.y > fy2.y ? fx2.y : fy2.y;
    return __floats2half2_rn(fx, fy);
}

struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return hmax2(x, y); }
};

// in_data is in shared memory
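// Tree reduction across one block's batch items: each pass folds the element
// at stride s into position 2*s*tid, until buffer[0] holds the reduced
// __half2. The two fp16 lanes of buffer[0] are then combined, so the returned
// value carries the full reduction broadcast into both lanes.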
template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            buffer[index] = op(buffer[index], buffer[index + s]);
        }
        __syncthreads();
    }

    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);

    return op(lows2, highs2);
}

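// Vectorized fp16 softmax over the last axis: one block per batch row, with
// the row staged twice in dynamic shared memory (one copy is consumed by the
// reductions, the other keeps the loaded values). batch_item_num is halved
// because each __half2 covers two elements, so this path assumes an even
// element count per row.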
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half2* input  = reinterpret_cast<__half2*>(data_in);
    __half2* output = reinterpret_cast<__half2*>(data_out);
    batch_item_num /= 2;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];

    __half2* in_data_reduce = buffer2;
    __half2* in_data        = buffer2 + batch_item_num;
    int start               = tid / block_size * batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    auto batch_max =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = h2exp(__hsub2(in_data[i], batch_max));
        in_data_reduce[i] = in_data[i];
    }

    auto batch_sum =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __h2div(in_data[i], batch_sum);
    }
}

// in_data is in shared memory
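// Scalar counterpart of block_reduce above: the same interleaved tree pattern
// over plain __half values, with data[0] holding the result.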
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            data[index] = op(data[index], data[index + s]);
        }
        __syncthreads();
    }

    return data[0];
}

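// Scalar __half softmax kernel: same max/exp/sum/divide steps as
// softmax_kernel, but one element per shared-memory slot, with the exp and
// the final divide widened to float for accuracy.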
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half* input  = reinterpret_cast<__half*>(data_in);
    __half* output = reinterpret_cast<__half*>(data_out);
    extern MIGRAPHX_DEVICE_SHARED __half buffer[];

    __half* in_data_reduce = buffer;
    __half* in_data        = buffer + batch_item_num;
    int start              = blockIdx.x * batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
        in_data_reduce[i] = in_data[i];
    }

    auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
    }
}

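// Host entry point: softmax is reduced along `axis`. batch_shape collapses
// that axis to 1, so batch_shape.elements() is the number of independent rows.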
void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    auto batch_lens          = result.get_shape().lens();
    index_int batch_item_num = batch_lens[axis];
    batch_lens[axis]         = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
        const index_int max_block_size = 1024;
        const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        type init  = lowest();

        if(axis == batch_lens.size() - 1)
        {
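            // Fast path: small fp16 rows go to the fused shared-memory kernel;
            // shared_size covers the two staging copies of each row.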
            auto in_type = result.get_shape().type();
            if(in_type == shape::half_type and batch_item_num <= 2048)
            {
                int block_num   = batch_shape.elements();
                int shared_size = batch_item_num * 2 * result.get_shape().type_size();

                softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
                    arg.data(), batch_item_num, block_size, result.data());
            }
            else
            {
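                // Otherwise fall back to the generic block_reduce from
                // reduce.hpp: compute the max, then the exp-sum, then the
                // normalized outputs in three strided passes.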
                gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        auto start_loc = i / block_size * batch_item_num;
                        auto batch_max = block_reduce<max_block_size>(
                            idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                                return input[start_loc + j];
                            });

                        auto batch_sum = block_reduce<max_block_size>(
                            idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                                auto val = input[start_loc + j] - batch_max;
                                return ::exp(to_hip_type(val));
                            });

                        idx.local_stride(batch_item_num, [&](auto j) __device__ {
                            auto val              = input[start_loc + j] - batch_max;
                            output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
                        });
                    });
            }
        }
        else
        {
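            // Softmax along an interior axis: walk a multi-index for each row
            // and vary only data_idx[axis] during the reductions.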
            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                [=](auto i, auto idx) __device__ {
                    auto data_idx  = batch.multi(i / block_size);
                    auto batch_max = block_reduce<max_block_size>(
                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            return input[data_idx];
                        });

                    auto batch_sum = block_reduce<max_block_size>(
                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            auto val       = input[data_idx] - batch_max;
                            return ::exp(to_hip_type(val));
                        });

                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
                        data_idx[axis]   = j;
                        auto val         = input[data_idx] - batch_max;
                        output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
                    });
                });
        }
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx