softmax.cpp 5.29 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

Shucai Xiao's avatar
Shucai Xiao committed
15
16
17
// Block-wide tree reduction (max) over the scratch range data_ptr[0, item_num),
// folded into the running maximum kept at data_ptr[block_size].
//
// Requirements (as used by softmax in this file):
//   * every thread of the block calls this with the same item_num, because the
//     loop contains __syncthreads();
//   * data_ptr[0, item_num) and data_ptr[block_size] were written and published
//     with a __syncthreads() before the call.
// Effects: data_ptr[0, item_num) is clobbered; after the trailing barrier,
// data_ptr[block_size] == max(previous data_ptr[block_size], max of the range).
// An empty range (item_num == 0) leaves data_ptr[block_size] unchanged.
template <class T>
__device__ void
reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Halve the active range each step: element i absorbs element i + stride.
    // stride = ceil(item_num / 2) handles odd sizes; writers (< stride) and
    // readers (>= stride) never overlap, so one barrier per step suffices.
    // Testing item_num > 1 up front (instead of item_num == 1 after a step)
    // also terminates for item_num == 0, where stride would stay 0 forever.
    while(item_num > 1)
    {
        auto stride = (item_num + 1) / 2;
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] =
                ::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
        }
        __syncthreads();
        item_num = stride;
    }

    // Fold this range's maximum into the running maximum stored past the
    // scratch area; skip it for an empty range so stale data_ptr[0] is not used.
    if(thr_idx == 0 and item_num > 0)
    {
        data_ptr[block_size] =
            (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
    }

    __syncthreads();
}

Shucai Xiao's avatar
Shucai Xiao committed
44
45
46
// Block-wide tree reduction (sum) over the scratch range data_ptr[0, item_num),
// accumulated into the running sum kept at data_ptr[block_size + 1].
//
// Requirements (as used by softmax in this file):
//   * every thread of the block calls this with the same item_num, because the
//     loop contains __syncthreads();
//   * data_ptr[0, item_num) and data_ptr[block_size + 1] were written and
//     published with a __syncthreads() before the call.
// Effects: data_ptr[0, item_num) is clobbered; after the trailing barrier,
// data_ptr[block_size + 1] has been increased by the sum of the range.
// An empty range (item_num == 0) leaves data_ptr[block_size + 1] unchanged.
template <class T>
__device__ void
reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Pairwise halving: element i accumulates element i + stride, with
    // stride = ceil(item_num / 2). Writers (< stride) and readers (>= stride)
    // are disjoint, so one barrier per step is enough. The up-front
    // item_num > 1 test also terminates for item_num == 0, which previously
    // looped forever because stride stayed 0.
    while(item_num > 1)
    {
        auto stride = (item_num + 1) / 2;
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] += data_ptr[thr_idx + stride];
        }
        __syncthreads();
        item_num = stride;
    }

    // Accumulate this range's total into the running sum; skip an empty range
    // so a stale data_ptr[0] is not added.
    if(thr_idx == 0 and item_num > 0)
    {
        data_ptr[block_size + 1] += data_ptr[0];
    }

    __syncthreads();
}

Shucai Xiao's avatar
Shucai Xiao committed
71
// Softmax of `arg` along `axis`, written into `result`:
//   result[.., i, ..] = exp(x_i - m) / sum_j exp(x_j - m),  m = max_j x_j,
// where i and j index the `axis` dimension and all other coordinates are fixed.
//
// Launch: one group of block_size threads per "batch", i.e. per element of the
// shape obtained by collapsing `axis` to length 1 (batch_shape). block_size is
// the smallest power of two >= lens[axis], capped at 1024; longer axes are
// processed in chunks of block_size elements.
//
// LDS layout (lds_data, max_block_size + 2 elements):
//   [0, block_size)   per-chunk scratch for reduce_max / reduce_sum
//   [block_size]      running maximum of the batch
//   [block_size + 1]  running sum of exp(x - max)
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto lens       = result.get_shape().lens();
    auto batch_lens = lens;
    size_t n_dims   = lens[axis];
    // A zero-length softmax axis means an empty tensor: there is nothing to
    // write, and the kernel would otherwise read a batch's first element out
    // of bounds.
    if(n_dims == 0)
        return;
    batch_lens[axis] = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

    visit_all(result, arg)([&](auto output, auto input) {
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());
        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
            hip_tensor_descriptor<n_dim> desc_data(result.get_shape());

            // use one block for items in one batch.
            const size_t max_block_size = 1024;
            size_t block_size           = 1;
            while(block_size < max_block_size and block_size < n_dims)
            {
                block_size *= 2;
            }

            launch(
                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
                size_t thr_idx = idx.local;
                size_t blk_idx = idx.group;
                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

                // all data can be loaded to the lds once, so all operations are
                // done in lds
                MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
                // batch_lens[axis] == 1, so batch_idx has coordinate 0 along axis:
                // it addresses the first element of this batch in the data tensor.
                auto batch_idx = desc_batch.multi(blk_idx);
                auto data_idx  = batch_idx;

                size_t item_num   = n_dims;
                size_t thread_num = (n_dims + block_size - 1) / block_size * block_size;

                // Seed the running max with an element of *this* batch. Seeding
                // with input_ptr[0] (the tensor's first element) let a large value
                // from another batch dominate the max, so every exp(x - max) here
                // could underflow to 0 and the division below became 0/0 = NaN.
                // Only thread 0 writes; the __syncthreads() in the first chunk
                // iteration below, which every thread executes because
                // n_dims >= 1, publishes both values before anyone reads them.
                if(thr_idx == 0)
                {
                    lds_data[block_size]     = input_ptr[desc_data.linear(batch_idx)];
                    lds_data[block_size + 1] = 0;
                }

                // Pass 1: batch maximum, one block_size chunk at a time. The trip
                // count thread_num / block_size is the same for every thread, so
                // the barriers inside are reached uniformly.
                for(size_t i = thr_idx; i < thread_num; i += block_size)
                {
                    if(i < n_dims)
                    {
                        data_idx[axis]    = i;
                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
                    }

                    __syncthreads();

                    // Only the first `size` slots hold data from this chunk.
                    auto size = (item_num > block_size) ? block_size : item_num;
                    reduce_max<type>(lds_data, block_size, thr_idx, size);

                    __syncthreads();

                    item_num -= block_size;
                }

                // Pass 2: sum of exp(x - max) over the batch.
                item_num = n_dims;
                for(size_t i = thr_idx; i < thread_num; i += block_size)
                {
                    if(i < n_dims)
                    {
                        data_idx[axis] = i;
                        lds_data[thr_idx] =
                            input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
                    }

                    __syncthreads();

                    auto size = (item_num > block_size) ? block_size : item_num;
                    reduce_sum<type>(lds_data, block_size, thr_idx, size);

                    __syncthreads();

                    item_num -= block_size;
                }

                // Pass 3: normalize and write the output.
                for(size_t i = thr_idx; i < n_dims; i += block_size)
                {
                    data_idx[axis]    = i;
                    size_t index      = desc_data.linear(data_idx);
                    auto val          = input_ptr[index] - lds_data[block_size];
                    output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1];
                }
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx