// softmax.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

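// Reduction operators on packed __half2 values: half2_sum adds the two half
// lanes element-wise, and half2_max (via hmax2) takes the element-wise maximum,
// going through float for the comparison.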
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return __hadd2(x, y); }
};

inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    auto fx2 = __half22float2(x);
    auto fy2 = __half22float2(y);
    auto fx  = fx2.x > fy2.x ? fx2.x : fy2.x;
    auto fy  = fx2.y > fy2.y ? fx2.y : fy2.y;
    return __floats2half2_rn(fx, fy);
}

struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return hmax2(x, y); }
};

// Tree-reduce batch_item_num packed __half2 elements held in shared memory
// (buffer) with op, then combine the low and high lanes of the result so that
// both lanes of the returned __half2 hold the fully reduced scalar.
template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Assumed fix: make every thread's shared-memory stores visible before the
    // tree reduction reads values written by other threads.
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            buffer[index] = op(buffer[index], buffer[index + s]);
        }
        __syncthreads();
    }

    auto lows2  = __low2half2(buffer[0]);
    auto highs2 = __high2half2(buffer[0]);
    // Assumed fix: sync before returning so callers that immediately overwrite the
    // shared buffer cannot race with the buffer[0] reads above.
    __syncthreads();

    return op(lows2, highs2);
}

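// Fused softmax over the innermost dimension for half-precision data. Each block
// handles one batch row of batch_item_num elements, processed as packed __half2
// pairs, and stages two copies in dynamic shared memory: one that is consumed by
// the reductions (in_data_reduce) and one that keeps the loaded/exponentiated
// values (in_data). The usual max-subtraction is applied for numerical stability.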
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half2* input  = reinterpret_cast<__half2*>(data_in);
    __half2* output = reinterpret_cast<__half2*>(data_out);
    batch_item_num /= 2;
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];

    __half2* in_data_reduce = buffer2;
    __half2* in_data        = buffer2 + batch_item_num;
    int start               = tid / block_size * batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
    }

    auto batch_max =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = h2exp(__hsub2(in_data[i], batch_max));
        in_data_reduce[i] = in_data[i];
    }

    auto batch_sum =
        block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __h2div(in_data[i], batch_sum);
    }
}

// Tree-reduce batch_item_num __half elements held in shared memory (data) with op;
// after the final barrier every thread reads the reduced value from data[0].
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
    // Assumed fix: make every thread's shared-memory stores visible before the
    // tree reduction reads values written by other threads.
    __syncthreads();
    for(index_int s = 1; s < block_size; s *= 2)
    {
        const index_int index = 2 * s * tid;
        if(index + s < batch_item_num)
        {
            data[index] = op(data[index], data[index + s]);
        }
        __syncthreads();
    }

    auto result = data[0];
    // Assumed fix: sync before returning so callers that immediately overwrite the
    // shared buffer cannot race with the data[0] read above.
    __syncthreads();
    return result;
}

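// Scalar __half variant of the fused softmax kernel: same structure as
// softmax_kernel but without __half2 packing, doing the exp and divide in float.
// Note that the host-side softmax() below does not currently launch this kernel.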
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
    __half* input  = reinterpret_cast<__half*>(data_in);
    __half* output = reinterpret_cast<__half*>(data_out);
    extern MIGRAPHX_DEVICE_SHARED __half buffer[];

    __half* in_data_reduce = buffer;
    __half* in_data        = buffer + batch_item_num;
    int start              = blockIdx.x * batch_item_num;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        auto d            = input[i + start];
        in_data[i]        = d;
        in_data_reduce[i] = d;
        // printf("blockIdx = %d, ori_val = %f\n", start, __half2float(d));
    }

    auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
    // printf("blockIdx = %d, batch_max = %f\n", start, __half2float(batch_max));
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
        in_data_reduce[i] = in_data[i];
        // printf("blockIdx = %d, exp_val = %f\n", start, __half2float(in_data[i]));
    }

    auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
    // printf("blockIdx = %d, batch_sum = %f\n", start, __half2float(batch_sum));

    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
    }
}

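// Host entry point: computes softmax of arg into result along the given axis.
// When the axis is the innermost dimension and the data is half precision with a
// small enough row (<= 2048 elements), the fused shared-memory kernel is launched;
// otherwise a generic reduction-based implementation is used.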
void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
    auto batch_lens          = result.get_shape().lens();
    index_int batch_item_num = batch_lens[axis];
    batch_lens[axis]         = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
        const index_int max_block_size = 128;
        const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        // Identity value for the max reduction: the lowest representable value.
        type init  = lowest();

        // Softmax over the innermost (contiguous) dimension can use flat offsets.
        if(axis == batch_lens.size() - 1)
        {
            auto in_type = result.get_shape().type();
            // The fused __half2 kernel needs an even element count (data is processed
            // in packed pairs) and a row small enough for shared memory. The evenness
            // check is an assumed guard not present in the original condition.
            if(in_type == shape::half_type and batch_item_num % 2 == 0 and
               batch_item_num <= 2048)
            {
                int block_num   = batch_shape.elements();
                // Two shared-memory staging arrays of batch_item_num halves each.
                int shared_size = batch_item_num * 2 * result.get_shape().type_size();
                softmax_kernel<<<block_num, block_size, shared_size, stream>>>(
                    arg.data(), batch_item_num, block_size, result.data());
            }
            else
            {
                // Fallback: generic block reductions over flat offsets for non-half
                // types or rows larger than the fused kernel supports.
                gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                    [=](auto i, auto idx) __device__ {
                        auto start_loc = i / block_size * batch_item_num;
                        auto batch_max = block_reduce<max_block_size>(
                            idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                                return input[start_loc + j];
                            });

                        auto batch_sum = block_reduce<max_block_size>(
                            idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                                auto val = input[start_loc + j] - batch_max;
                                return ::exp(to_hip_type(val));
                            });

                        idx.local_stride(batch_item_num, [&](auto j) __device__ {
                            auto val              = input[start_loc + j] - batch_max;
                            output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
                        });
                    });
            }
        }
        else
        {
            // General case: reduce along an arbitrary axis using multi-dimensional
            // indexing into the batch shape.
            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
                [=](auto i, auto idx) __device__ {
                    auto data_idx  = batch.multi(i / block_size);
                    auto batch_max = block_reduce<max_block_size>(
                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            return input[data_idx];
                        });

                    auto batch_sum = block_reduce<max_block_size>(
                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                            data_idx[axis] = j;
                            auto val       = input[data_idx] - batch_max;
                            return ::exp(to_hip_type(val));
                        });

                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
                        data_idx[axis]   = j;
                        auto val         = input[data_idx] - batch_max;
                        output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
                    });
                });
        }
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx