// softmax.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Softmax of args.front() along `axis`, written into args.back().
//
// Every index of `output_shape` with lens[axis] collapsed to 1 identifies one
// "batch": the slice of n_dims = lens[axis] elements along `axis`. Each batch
// is processed by one work-group of block_size threads, which walks the slice
// in chunks of block_size elements three times:
//   1. reduce the batch maximum,
//   2. reduce the sum of exp(x - max),
//   3. write exp(x - max) / sum to the output.
// Subtracting the batch maximum keeps exp() from overflowing.
//
// stream       - HIP stream the kernel is launched on.
// output_shape - shape of the output; its descriptor is also used to index the
//                input (NOTE(review): assumes input and output share a layout —
//                confirm with callers).
// args         - args.front() is read as the input, args.back() is written.
// axis         - dimension to normalize; used directly as an index into the
//                lens vector, so it must lie in [0, rank).
// Returns args.back().
argument softmax(hipStream_t stream,
                 const migraphx::shape& output_shape,
                 std::vector<migraphx::argument> args,
                 int axis)
{
    auto lens        = output_shape.lens();
    auto batch_lens  = lens;
    size_t n_dims    = lens[axis];
    batch_lens[axis] = 1;
    migraphx::shape batch_shape{output_shape.type(), batch_lens};

    // An empty softmax axis means an empty output: nothing to compute, and the
    // kernel below would otherwise read a non-existent first element.
    if(n_dims == 0)
        return args.back();

    visit_all(args.back(), args.front())([&](auto output, auto input) {
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());
        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
            hip_tensor_descriptor<n_dim> desc_data(output_shape);

            // One work-group per batch. Its size is the smallest power of two
            // that is >= n_dims, capped at max_block_size.
            constexpr size_t max_block_size = 1024;
            size_t block_size               = 1;
            while(block_size < max_block_size and block_size < n_dims)
            {
                block_size *= 2;
            }

            // n_dims rounded up to a multiple of block_size: every thread runs
            // the same number of chunk iterations, so the __syncthreads() calls
            // inside the loops below are reached by the whole group.
            size_t thread_num = (n_dims + block_size - 1) / block_size * block_size;

            launch(stream, batch_shape.elements() * block_size, block_size)(
                [=](auto idx) __device__ {
                    size_t thr_idx = idx.local;
                    size_t blk_idx = idx.group;
                    using type =
                        device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

                    // lds_data[0, block_size) : scratch for one chunk's reduction
                    // lds_data[block_size]    : running maximum of the batch
                    // lds_data[block_size + 1]: running sum of exp(x - max)
                    MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];

                    auto data_idx = desc_batch.multi(blk_idx);

                    // Seed the running maximum with this batch's own first element.
                    // Seeding with input_ptr[0] (the first element of the whole
                    // tensor) can pick a value far larger than anything in this
                    // batch, making every exp(x - max) underflow to 0 and the
                    // result 0 / 0. Only thread 0 writes, then the group syncs.
                    if(thr_idx == 0)
                    {
                        data_idx[axis]           = 0;
                        lds_data[block_size]     = input_ptr[desc_data.linear(data_idx)];
                        lds_data[block_size + 1] = 0;
                    }
                    __syncthreads();

                    // Pass 1: maximum over the batch.
                    for(size_t i = thr_idx; i < thread_num; i += block_size)
                    {
                        // Valid elements in this chunk (the last chunk may be partial).
                        size_t chunk_len = n_dims - (i - thr_idx);
                        if(chunk_len > block_size)
                            chunk_len = block_size;

                        if(i < n_dims)
                        {
                            data_idx[axis]    = i;
                            lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
                        }
                        __syncthreads();

                        // Tree reduction of lds_data[0, chunk_len) into lds_data[0].
                        for(size_t len = chunk_len; len > 1;)
                        {
                            size_t stride = (len + 1) / 2;
                            if(thr_idx + stride < len)
                            {
                                lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
                                                          to_hip_type(lds_data[thr_idx + stride]));
                            }
                            __syncthreads();
                            len = stride;
                        }

                        if(thr_idx == 0)
                        {
                            lds_data[block_size] = (lds_data[0] < lds_data[block_size])
                                                       ? lds_data[block_size]
                                                       : lds_data[0];
                        }
                        __syncthreads();
                    }

                    // Pass 2: sum of exp(x - max) over the batch.
                    for(size_t i = thr_idx; i < thread_num; i += block_size)
                    {
                        size_t chunk_len = n_dims - (i - thr_idx);
                        if(chunk_len > block_size)
                            chunk_len = block_size;

                        if(i < n_dims)
                        {
                            data_idx[axis] = i;
                            lds_data[thr_idx] =
                                input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
                            lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
                        }
                        __syncthreads();

                        // Tree reduction of lds_data[0, chunk_len) into lds_data[0].
                        for(size_t len = chunk_len; len > 1;)
                        {
                            size_t stride = (len + 1) / 2;
                            if(thr_idx + stride < len)
                            {
                                lds_data[thr_idx] += lds_data[thr_idx + stride];
                            }
                            __syncthreads();
                            len = stride;
                        }

                        if(thr_idx == 0)
                        {
                            lds_data[block_size + 1] += lds_data[0];
                        }
                        __syncthreads();
                    }

                    // Pass 3: normalize and write the output.
                    for(size_t i = thr_idx; i < n_dims; i += block_size)
                    {
                        data_idx[axis]    = i;
                        size_t index      = desc_data.linear(data_idx);
                        auto val          = input_ptr[index] - lds_data[block_size];
                        output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1];
                    }
                });
        });
    });

    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx