// softmax.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

Shucai Xiao's avatar
Shucai Xiao committed
15
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
Khalique's avatar
Khalique committed
16
{
17
    auto lens        = result.get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
18
19
    auto batch_lens  = lens;
    size_t n_dims    = lens[axis];
20
    batch_lens[axis] = 1;
21
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
Khalique's avatar
Khalique committed
22

23
    visit_all(result, arg)([&](auto output, auto input) {
Khalique's avatar
Khalique committed
24
25
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());
26
27
        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
28
            hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
Khalique's avatar
Khalique committed
29

30
            // use one block for items in one batch.
31
            const size_t max_block_size = 1024;
Shucai Xiao's avatar
Shucai Xiao committed
32
33
            size_t block_size           = 1;
            while(block_size < max_block_size and block_size < n_dims)
34
35
36
37
            {
                block_size *= 2;
            }

38
39
40
41
42
43
44
45
            launch(
                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
                size_t thr_idx = idx.local;
                size_t blk_idx = idx.group;
                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

                // all data can be loaded to the lds once, so all operations are
                // done in lds
46
                MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
47
                auto batch_idx = desc_batch.multi(blk_idx);
Shucai Xiao's avatar
Shucai Xiao committed
48
                auto data_idx  = batch_idx;
49
                // load data to lds and compute the batch max
Shucai Xiao's avatar
Shucai Xiao committed
50
51
52
53
                size_t item_num          = n_dims;
                size_t thread_num        = (n_dims + block_size - 1) / block_size * block_size;
                lds_data[block_size]     = input_ptr[0];
                lds_data[block_size + 1] = 0;
54
                for(size_t i = thr_idx; i < thread_num; i += block_size)
55
                {
Shucai Xiao's avatar
Shucai Xiao committed
56
                    if(i < n_dims)
57
                    {
Shucai Xiao's avatar
Shucai Xiao committed
58
59
                        data_idx[axis]    = i;
                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
60
                    }
Khalique's avatar
Khalique committed
61

62
                    __syncthreads();
Khalique's avatar
Khalique committed
63

64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
                    auto size   = (item_num > block_size) ? block_size : item_num;
                    auto stride = (size + 1) / 2;
                    while(true)
                    {
                        if(thr_idx + stride < size)
                        {
                            lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
                                                      to_hip_type(lds_data[thr_idx + stride]));
                        }
                        __syncthreads();
                        size   = stride;
                        stride = (stride + 1) / 2;

                        if(size == 1)
                            break;
                    }

                    if(thr_idx == 0)
                    {
                        lds_data[block_size] = (lds_data[0] < lds_data[block_size])
                                                   ? lds_data[block_size]
                                                   : lds_data[0];
                    }
                    __syncthreads();

                    item_num -= block_size;
90
                }
Khalique's avatar
Khalique committed
91

Shucai Xiao's avatar
Shucai Xiao committed
92
                item_num = n_dims;
93
                for(size_t i = thr_idx; i < thread_num; i += block_size)
94
                {
Shucai Xiao's avatar
Shucai Xiao committed
95
                    if(i < n_dims)
96
97
                    {
                        data_idx[axis] = i;
Shucai Xiao's avatar
Shucai Xiao committed
98
99
100
                        lds_data[thr_idx] =
                            input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
101
                    }
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

                    __syncthreads();

                    auto size   = (item_num > block_size) ? block_size : item_num;
                    auto stride = (size + 1) / 2;
                    while(true)
                    {
                        if(thr_idx + stride < size)
                        {
                            lds_data[thr_idx] += lds_data[thr_idx + stride];
                        }
                        __syncthreads();
                        size   = stride;
                        stride = (stride + 1) / 2;
                        if(size == 1)
                            break;
                    }

                    if(thr_idx == 0)
                    {
122
                        lds_data[block_size + 1] += lds_data[0];
123
124
125
126
                    }
                    __syncthreads();

                    item_num -= block_size;
127
                }
Khalique's avatar
Khalique committed
128

129
                for(size_t i = thr_idx; i < n_dims; i += block_size)
130
                {
131
132
133
                    data_idx[axis]    = i;
                    size_t index      = desc_data.linear(data_idx);
                    auto val          = input_ptr[index] - lds_data[block_size];
134
                    output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1];
135
136
                }
            });
Khalique's avatar
Khalique committed
137
138
139
140
141
142
143
144
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx