// softmax.cpp - HIP device-side softmax kernel.
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void softmax(hipStream_t stream,
                 const argument& result,
                 const argument& arg,
Khalique's avatar
Khalique committed
18
                 int axis)
Khalique's avatar
Khalique committed
19
{
20
    auto lens        = result.get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
21
22
    auto batch_lens  = lens;
    size_t n_dims    = lens[axis];
23
    batch_lens[axis] = 1;
24
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
Khalique's avatar
Khalique committed
25

26
    visit_all(result, arg)([&](auto output, auto input) {
Khalique's avatar
Khalique committed
27
28
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());
29
30
        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
31
            hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
Khalique's avatar
Khalique committed
32

33
            // use one block for items in one batch.
34
            const size_t max_block_size = 1024;
Shucai Xiao's avatar
Shucai Xiao committed
35
36
            size_t block_size           = 1;
            while(block_size < max_block_size and block_size < n_dims)
37
38
39
40
            {
                block_size *= 2;
            }

41
42
43
44
45
46
47
48
            launch(
                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
                size_t thr_idx = idx.local;
                size_t blk_idx = idx.group;
                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

                // all data can be loaded to the lds once, so all operations are
                // done in lds
49
                MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
50
                auto batch_idx = desc_batch.multi(blk_idx);
Shucai Xiao's avatar
Shucai Xiao committed
51
                auto data_idx  = batch_idx;
52
                // load data to lds and compute the batch max
Shucai Xiao's avatar
Shucai Xiao committed
53
54
55
56
                size_t item_num          = n_dims;
                size_t thread_num        = (n_dims + block_size - 1) / block_size * block_size;
                lds_data[block_size]     = input_ptr[0];
                lds_data[block_size + 1] = 0;
57
                for(size_t i = thr_idx; i < thread_num; i += block_size)
58
                {
Shucai Xiao's avatar
Shucai Xiao committed
59
                    if(i < n_dims)
60
                    {
Shucai Xiao's avatar
Shucai Xiao committed
61
62
                        data_idx[axis]    = i;
                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
63
                    }
Khalique's avatar
Khalique committed
64

65
                    __syncthreads();
Khalique's avatar
Khalique committed
66

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
                    auto size   = (item_num > block_size) ? block_size : item_num;
                    auto stride = (size + 1) / 2;
                    while(true)
                    {
                        if(thr_idx + stride < size)
                        {
                            lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
                                                      to_hip_type(lds_data[thr_idx + stride]));
                        }
                        __syncthreads();
                        size   = stride;
                        stride = (stride + 1) / 2;

                        if(size == 1)
                            break;
                    }

                    if(thr_idx == 0)
                    {
                        lds_data[block_size] = (lds_data[0] < lds_data[block_size])
                                                   ? lds_data[block_size]
                                                   : lds_data[0];
                    }
                    __syncthreads();

                    item_num -= block_size;
93
                }
Khalique's avatar
Khalique committed
94

Shucai Xiao's avatar
Shucai Xiao committed
95
                item_num = n_dims;
96
                for(size_t i = thr_idx; i < thread_num; i += block_size)
97
                {
Shucai Xiao's avatar
Shucai Xiao committed
98
                    if(i < n_dims)
99
100
                    {
                        data_idx[axis] = i;
Shucai Xiao's avatar
Shucai Xiao committed
101
102
103
                        lds_data[thr_idx] =
                            input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
104
                    }
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124

                    __syncthreads();

                    auto size   = (item_num > block_size) ? block_size : item_num;
                    auto stride = (size + 1) / 2;
                    while(true)
                    {
                        if(thr_idx + stride < size)
                        {
                            lds_data[thr_idx] += lds_data[thr_idx + stride];
                        }
                        __syncthreads();
                        size   = stride;
                        stride = (stride + 1) / 2;
                        if(size == 1)
                            break;
                    }

                    if(thr_idx == 0)
                    {
125
                        lds_data[block_size + 1] += lds_data[0];
126
127
128
129
                    }
                    __syncthreads();

                    item_num -= block_size;
130
                }
Khalique's avatar
Khalique committed
131

132
                for(size_t i = thr_idx; i < n_dims; i += block_size)
133
                {
134
135
136
                    data_idx[axis]    = i;
                    size_t index      = desc_data.linear(data_idx);
                    auto val          = input_ptr[index] - lds_data[block_size];
137
                    output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1];
138
139
                }
            });
Khalique's avatar
Khalique committed
140
141
142
143
144
145
146
147
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx