// softmax.cpp (5.12 KB)
// NOTE(review): this header was residue from a web "blame" view (author:
// Khalique, gutter line numbers); converted to a comment so the file compiles.
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

/// Computes softmax of `args.front()` along `axis` and writes it into `args.back()`.
///
/// One workgroup of `block_size` threads is launched per "batch". A batch is a
/// position in `output_shape` with the `axis` coordinate removed. The elements
/// along `axis` are processed in chunks of at most `block_size`. Each chunk is
/// staged in the workgroup's LDS buffer and tree-reduced, first to get the batch
/// maximum and then to accumulate sum(exp(x - max)). A final pass writes
/// exp(x - max) / sum.
///
/// @param stream       HIP stream the kernel is launched on.
/// @param output_shape shape of the output; also used to index the input.
/// @param args         args.front() is the input, args.back() the output buffer.
/// @param axis         dimension to normalise over. It is used directly as an index
///                     into output_shape.lens() (NOTE(review): assumes the caller
///                     already normalised it to a valid non-negative axis).
/// @return args.back()
argument softmax(hipStream_t stream,
                 const migraphx::shape& output_shape,
                 std::vector<migraphx::argument> args,
                 int axis)
{
    auto lens        = output_shape.lens();
    auto batch_lens  = lens;
    size_t n_dims    = lens[axis];
    batch_lens[axis] = 1;
    migraphx::shape batch_shape{shape::int32_type, batch_lens};

    visit_all(args.back(), args.front())([&](auto output, auto input) {
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());
        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
            hip_tensor_descriptor<n_dim> desc_data(output_shape);

            // use one block for items in one batch.
            const size_t block_size = 1024;
            launch(
                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
                size_t thr_idx = idx.local;
                size_t blk_idx = idx.group;
                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

                // n_dims is the same for every thread, so this early return cannot
                // leave some threads waiting at a barrier the others skip.
                if(n_dims == 0)
                    return;

                // LDS layout: [0, block_size) stages one chunk of the axis,
                // [max_idx] holds the running maximum and [sum_idx] the running
                // sum of exponentials. The axis may be longer than block_size,
                // so it is processed chunk by chunk.
                MIGRAPHX_DEVICE_SHARED type lds_data[block_size + 2];
                const size_t max_idx = block_size;
                const size_t sum_idx = block_size + 1;

                auto batch_idx = desc_batch.multi(blk_idx);
                auto data_idx  = batch_idx;

                // Seed the running max with the first element of THIS batch. Seeding
                // with an element of another batch (e.g. input_ptr[0]) can exceed
                // every value here, underflow all exp() to 0 and yield NaN/inf.
                if(thr_idx == 0)
                {
                    data_idx[axis]    = 0;
                    lds_data[max_idx] = input_ptr[desc_data.linear(data_idx)];
                    lds_data[sum_idx] = 0;
                }

                // Pass 1: batch maximum.
                // The chunk loop is uniform across the block so every thread reaches
                // every __syncthreads(); only the load itself is guarded.
                for(size_t start = 0; start < n_dims; start += block_size)
                {
                    const size_t i = start + thr_idx;
                    if(i < n_dims)
                    {
                        data_idx[axis]    = i;
                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
                    }
                    __syncthreads();

                    const size_t remaining = n_dims - start;
                    size_t size            = (remaining > block_size) ? block_size : remaining;
                    size_t stride          = (size + 1) / 2;
                    while(true)
                    {
                        if(thr_idx + stride < size)
                        {
                            lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
                                                      to_hip_type(lds_data[thr_idx + stride]));
                        }
                        __syncthreads();
                        size   = stride;
                        stride = (stride + 1) / 2;

                        if(size == 1)
                            break;
                    }

                    if(thr_idx == 0)
                    {
                        lds_data[max_idx] =
                            (lds_data[0] < lds_data[max_idx]) ? lds_data[max_idx] : lds_data[0];
                    }
                    // Also keeps the next chunk's loads from overwriting lds_data[0]
                    // before thread 0 has consumed it.
                    __syncthreads();
                }

                // Pass 2: sum of exp(x - max), same chunking scheme as pass 1.
                for(size_t start = 0; start < n_dims; start += block_size)
                {
                    const size_t i = start + thr_idx;
                    if(i < n_dims)
                    {
                        data_idx[axis] = i;
                        lds_data[thr_idx] =
                            input_ptr[desc_data.linear(data_idx)] - lds_data[max_idx];
                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
                    }
                    __syncthreads();

                    const size_t remaining = n_dims - start;
                    size_t size            = (remaining > block_size) ? block_size : remaining;
                    size_t stride          = (size + 1) / 2;
                    while(true)
                    {
                        if(thr_idx + stride < size)
                        {
                            lds_data[thr_idx] += lds_data[thr_idx + stride];
                        }
                        __syncthreads();
                        size   = stride;
                        stride = (stride + 1) / 2;
                        if(size == 1)
                            break;
                    }

                    if(thr_idx == 0)
                    {
                        lds_data[sum_idx] += lds_data[0];
                    }
                    __syncthreads();
                }

                // Pass 3: normalise. No barriers here, so a plain strided loop is safe.
                for(size_t i = thr_idx; i < n_dims; i += block_size)
                {
                    data_idx[axis]    = i;
                    size_t index      = desc_data.linear(data_idx);
                    auto val          = input_ptr[index] - lds_data[max_idx];
                    output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[sum_idx];
                }
            });
        });
    });

    return args.back();
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx