softmax.cpp 3.55 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
5
#include <migraphx/gpu/device/reduce_opers.hpp>
Khalique's avatar
Khalique committed
6
7
8
9
10
11
12
13
14
15
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

16
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
Khalique's avatar
Khalique committed
17
{
Shucai Xiao's avatar
Shucai Xiao committed
18
19
20
21
    auto lens             = result.get_shape().lens();
    auto batch_lens       = lens;
    size_t batch_item_num = lens[axis];
    batch_lens[axis]      = 1;
22
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
Khalique's avatar
Khalique committed
23

Paul's avatar
Paul committed
24
    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
25
26
27
28
29
30
31
        // use one block for items in one batch.
        const size_t max_block_size = 1024;
        size_t block_size           = 1;
        while(block_size < max_block_size and block_size < batch_item_num)
        {
            block_size *= 2;
        }
Khalique's avatar
Khalique committed
32

Shucai Xiao's avatar
Shucai Xiao committed
33
        launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
34
35
            size_t thr_idx = idx.local;
            size_t blk_idx = idx.group;
Shucai Xiao's avatar
Shucai Xiao committed
36
            using type     = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
Khalique's avatar
Khalique committed
37

38
39
40
41
42
            MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
            auto batch_idx = batch.multi(blk_idx);
            auto data_idx  = batch_idx;
            // load data to lds and compute the batch max
            size_t remaining_item_num = batch_item_num;
Shucai Xiao's avatar
Shucai Xiao committed
43
            size_t round_item_num     = (batch_item_num + block_size - 1) / block_size * block_size;
Shucai Xiao's avatar
Shucai Xiao committed
44
            lds_data[max_block_size]  = input[0];
45
            for(size_t i = thr_idx; i < round_item_num; i += block_size)
46
            {
47
                if(i < batch_item_num)
48
                {
49
                    data_idx[axis]    = i;
Shucai Xiao's avatar
Shucai Xiao committed
50
                    lds_data[thr_idx] = input[data_idx];
51
                }
Khalique's avatar
Khalique committed
52

Shucai Xiao's avatar
Shucai Xiao committed
53
54
                __syncthreads();

Shucai Xiao's avatar
Shucai Xiao committed
55
                auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
56
                reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
57

58
                remaining_item_num -= block_size;
Paul's avatar
Paul committed
59
            }
60

61
            auto batch_max = lds_data[max_block_size];
62
            __syncthreads();
63

64
            lds_data[max_block_size] = 0;
Shucai Xiao's avatar
Shucai Xiao committed
65
            remaining_item_num       = batch_item_num;
66
            for(size_t i = thr_idx; i < round_item_num; i += block_size)
Paul's avatar
Paul committed
67
            {
68
                if(i < batch_item_num)
69
                {
70
                    data_idx[axis]    = i;
Shucai Xiao's avatar
Shucai Xiao committed
71
                    lds_data[thr_idx] = input[data_idx] - batch_max;
72
                    lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
73
                }
74
75
76

                __syncthreads();

Shucai Xiao's avatar
Shucai Xiao committed
77
                auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
78
                reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
79
80

                remaining_item_num -= block_size;
Paul's avatar
Paul committed
81
            }
82
            auto batch_sum = lds_data[max_block_size];
Khalique's avatar
Khalique committed
83

84
            for(size_t i = thr_idx; i < batch_item_num; i += block_size)
Paul's avatar
Paul committed
85
            {
Shucai Xiao's avatar
Shucai Xiao committed
86
87
88
                data_idx[axis]   = i;
                auto val         = input[data_idx] - batch_max;
                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
Paul's avatar
Paul committed
89
            }
Khalique's avatar
Khalique committed
90
91
92
93
94
95
96
97
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx