// softmax.cpp -- contributors: Khalique, Shucai Xiao
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

/// Softmax of `arg` along dimension `axis`, written into `result` (GPU).
///
/// The tensor is split into "batches": each index combination of the other
/// dimensions selects one 1-D slice along `axis`. Each slice is handled by one
/// work-group, which (1) reduces the slice maximum, (2) reduces the sum of
/// exp(x - max), and (3) writes exp(x - max) / sum to the output.
///
/// @param stream  HIP stream the kernel is launched on.
/// @param result  Output tensor. Its shape is also used to index `arg`, so the
///                input is assumed to share its lens and layout.
/// @param arg     Input tensor.
/// @param axis    Dimension to normalize over; negative values count from the
///                last dimension (-1 == innermost).
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto lens = result.get_shape().lens();

    // Accept negative axes: indexing `lens` with a negative int would convert
    // it to a huge size_t and read out of bounds.
    if(axis < 0)
    {
        axis += static_cast<int>(lens.size());
    }

    // An empty tensor has nothing to normalize; bail out before launching a
    // zero-sized grid or reading input elements that do not exist.
    if(result.get_shape().elements() == 0)
    {
        return;
    }

    auto batch_lens       = lens;
    size_t batch_item_num = lens[axis];
    batch_lens[axis]      = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};

    visit_all(result, arg)([&](auto output, auto input) {
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());

        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
            hip_tensor_descriptor<n_dim> desc_data(result.get_shape());

            // Use one block for the items of one batch. The block size is the
            // smallest power of two >= batch_item_num, capped at max_block_size;
            // longer batches are processed in several rounds below.
            const size_t max_block_size = 1024;
            size_t block_size           = 1;
            while(block_size < max_block_size and block_size < batch_item_num)
            {
                block_size *= 2;
            }

            launch(
                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
                size_t thr_idx = idx.local;
                size_t blk_idx = idx.group;
                using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

                // Slots [0, block_size) hold the current round of inputs; slot
                // block_size holds the running reduction result.
                MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
                auto batch_idx = desc_batch.multi(blk_idx);
                auto data_idx  = batch_idx;

                size_t remaining_item_num = batch_item_num;
                size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;

                // Seed the running max with the first element of *this* batch.
                // Seeding with input_ptr[0] (an element of the first batch) can
                // leave the "max" above every value of this batch, so that all
                // exp(x - max) underflow to 0 and the output becomes 0/0 = NaN.
                data_idx[axis]       = 0;
                lds_data[block_size] = input_ptr[desc_data.linear(data_idx)];

                // Pass 1: batch maximum.
                for(size_t i = thr_idx; i < round_item_num; i += block_size)
                {
                    // Threads past the end of the batch leave their slot
                    // untouched; item_num tells the reduction how many slots
                    // of this round are valid.
                    if(i < batch_item_num)
                    {
                        data_idx[axis]    = i;
                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
                    }

                    __syncthreads();

                    auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
                    // NOTE(review): reduce_max (reduce_opers.hpp) is assumed to
                    // fold lds_data[0, item_num) into lds_data[block_size] and to
                    // synchronize before returning -- confirm.
                    reduce_max<type>(lds_data, block_size, thr_idx, item_num);

                    // May wrap (size_t) on the final round; the value is not
                    // used after the loop exits.
                    remaining_item_num -= block_size;
                }

                auto batch_max = lds_data[block_size];
                // Every thread must read the max before slot block_size is reused.
                __syncthreads();

                // Pass 2: sum of exp(x - max).
                lds_data[block_size] = 0;
                remaining_item_num   = batch_item_num;
                for(size_t i = thr_idx; i < round_item_num; i += block_size)
                {
                    if(i < batch_item_num)
                    {
                        data_idx[axis]    = i;
                        lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - batch_max;
                        lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
                    }

                    __syncthreads();

                    auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
                    // NOTE(review): same assumption as reduce_max above.
                    reduce_sum<type>(lds_data, block_size, thr_idx, item_num);

                    remaining_item_num -= block_size;
                }
                auto batch_sum = lds_data[block_size];

                // Pass 3: normalize. Each thread handles elements thr_idx,
                // thr_idx + block_size, ... of the batch.
                for(size_t i = thr_idx; i < batch_item_num; i += block_size)
                {
                    data_idx[axis]    = i;
                    size_t index      = desc_data.linear(data_idx);
                    auto val          = input_ptr[index] - batch_max;
                    output_ptr[index] = ::exp(to_hip_type(val)) / batch_sum;
                }
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx