softmax.cpp 2.21 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
5
#include <migraphx/gpu/device/reduce.hpp>
Khalique's avatar
Khalique committed
6
7
8
9
10
11
12
13
14
15
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

16
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
Khalique's avatar
Khalique committed
17
{
Shucai Xiao's avatar
Shucai Xiao committed
18
19
    auto lens                  = result.get_shape().lens();
    auto batch_lens            = lens;
Shucai Xiao's avatar
Shucai Xiao committed
20
    std::size_t batch_item_num = lens[axis];
Shucai Xiao's avatar
Shucai Xiao committed
21
    batch_lens[axis]           = 1;
22
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
Khalique's avatar
Khalique committed
23

Paul's avatar
Paul committed
24
    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
25
26
        const std::size_t max_block_size = 256;
        const std::size_t block_size     = compute_block_size(batch_item_num, max_block_size);
Shucai Xiao's avatar
Shucai Xiao committed
27
28
29
        gs_launch(stream,
                  batch_shape.elements() * block_size,
                  block_size)([=](auto i, auto idx) __device__ {
30
            auto data_idx = batch.multi(i / block_size);
Shucai Xiao's avatar
Shucai Xiao committed
31
32
            using type    = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
            type init     = lowest();
33

Shucai Xiao's avatar
Shucai Xiao committed
34
35
36
37
38
            auto batch_max = block_reduce<max_block_size>(
                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
                    data_idx[axis] = j;
                    return input[data_idx];
                });
39

Shucai Xiao's avatar
Shucai Xiao committed
40
41
42
43
44
45
            auto batch_sum =
                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
                    data_idx[axis] = j;
                    auto val       = input[data_idx] - batch_max;
                    return ::exp(to_hip_type(val));
                });
46

Shucai Xiao's avatar
Shucai Xiao committed
47
48
49
            idx.local_stride(batch_item_num, [&](auto j) {
                data_idx[axis]   = j;
                auto val         = input[data_idx] - batch_max;
Shucai Xiao's avatar
Shucai Xiao committed
50
                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
51
            });
Khalique's avatar
Khalique committed
52
53
54
55
56
57
58
59
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx