softmax.cpp 3.73 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
5
#include <migraphx/gpu/device/reduce_opers.hpp>
Khalique's avatar
Khalique committed
6
7
8
9
10
11
12
13
14
15
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

16
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
Khalique's avatar
Khalique committed
17
{
Shucai Xiao's avatar
Shucai Xiao committed
18
19
    auto lens                  = result.get_shape().lens();
    auto batch_lens            = lens;
Shucai Xiao's avatar
Shucai Xiao committed
20
    std::size_t batch_item_num = lens[axis];
Shucai Xiao's avatar
Shucai Xiao committed
21
    batch_lens[axis]           = 1;
22
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
Khalique's avatar
Khalique committed
23

Paul's avatar
Paul committed
24
    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
25
        // use one block for items in one batch.
Shucai Xiao's avatar
Shucai Xiao committed
26
27
        const std::size_t max_block_size = 1024;
        std::size_t block_size           = 1;
28
29
30
31
        while(block_size < max_block_size and block_size < batch_item_num)
        {
            block_size *= 2;
        }
Khalique's avatar
Khalique committed
32

Shucai Xiao's avatar
Shucai Xiao committed
33
        launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
Shucai Xiao's avatar
Shucai Xiao committed
34
35
            std::size_t thr_idx = idx.local;
            std::size_t blk_idx = idx.group;
Shucai Xiao's avatar
Shucai Xiao committed
36
            using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
Khalique's avatar
Khalique committed
37

38
39
40
41
            MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
            auto batch_idx = batch.multi(blk_idx);
            auto data_idx  = batch_idx;
            // load data to lds and compute the batch max
Shucai Xiao's avatar
Shucai Xiao committed
42
            std::size_t remaining_item_num = batch_item_num;
Shucai Xiao's avatar
Shucai Xiao committed
43
44
45
            std::size_t round_item_num =
                (batch_item_num + block_size - 1) / block_size * block_size;
            lds_data[max_block_size] = input[0];
Shucai Xiao's avatar
Shucai Xiao committed
46
            for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
47
            {
48
                if(i < batch_item_num)
49
                {
50
                    data_idx[axis]    = i;
Shucai Xiao's avatar
Shucai Xiao committed
51
                    lds_data[thr_idx] = input[data_idx];
52
                }
Khalique's avatar
Khalique committed
53

Shucai Xiao's avatar
Shucai Xiao committed
54
55
                __syncthreads();

Shucai Xiao's avatar
Shucai Xiao committed
56
                auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
Shucai Xiao's avatar
Shucai Xiao committed
57
58
                block_reduce<type, max_op<type>>(
                    lds_data, max_op<type>{}, block_size, thr_idx, item_num, max_block_size);
59
                remaining_item_num -= block_size;
Paul's avatar
Paul committed
60
            }
61

62
            auto batch_max = lds_data[max_block_size];
63
            __syncthreads();
64

65
            lds_data[max_block_size] = 0;
Shucai Xiao's avatar
Shucai Xiao committed
66
            remaining_item_num       = batch_item_num;
Shucai Xiao's avatar
Shucai Xiao committed
67
            for(std::size_t i = thr_idx; i < round_item_num; i += block_size)
Paul's avatar
Paul committed
68
            {
69
                if(i < batch_item_num)
70
                {
71
                    data_idx[axis]    = i;
Shucai Xiao's avatar
Shucai Xiao committed
72
                    lds_data[thr_idx] = input[data_idx] - batch_max;
73
                    lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
74
                }
75
76
77

                __syncthreads();

Shucai Xiao's avatar
Shucai Xiao committed
78
                auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
Shucai Xiao's avatar
Shucai Xiao committed
79
80
                block_reduce<type, sum_op<type>>(
                    lds_data, sum_op<type>{}, block_size, thr_idx, item_num, max_block_size);
81
82

                remaining_item_num -= block_size;
Paul's avatar
Paul committed
83
            }
84
            auto batch_sum = lds_data[max_block_size];
Khalique's avatar
Khalique committed
85

Shucai Xiao's avatar
Shucai Xiao committed
86
            for(std::size_t i = thr_idx; i < batch_item_num; i += block_size)
Paul's avatar
Paul committed
87
            {
Shucai Xiao's avatar
Shucai Xiao committed
88
89
90
                data_idx[axis]   = i;
                auto val         = input[data_idx] - batch_max;
                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
Paul's avatar
Paul committed
91
            }
Khalique's avatar
Khalique committed
92
93
94
95
96
97
98
99
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx