softmax.cpp 2.17 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
5
#include <migraphx/gpu/device/reduce.hpp>
Khalique's avatar
Khalique committed
6
7
8
9
10
11
12
13
14
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

15
/// Computes the softmax of `arg` along dimension `axis` and stores it in
/// `result`, submitting the work through `gs_launch` on `stream`.
///
/// A "slice" below is the run of `result.get_shape().lens()[axis]` elements
/// obtained by fixing every index except `axis`. Each slice is processed in
/// three passes: reduce its maximum, reduce the sum of exp(x - maximum), then
/// write exp(x - maximum) / sum for every element of the slice.
///
/// `axis` is used directly as an index into the dimension vector; no range
/// check or negative-axis adjustment is performed here.
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    // Output dimensions with the softmax axis collapsed to 1: every element of
    // the resulting shape addresses exactly one slice.
    auto slice_lens     = result.get_shape().lens();
    index_int slice_len = slice_lens[axis];
    slice_lens[axis]    = 1;
    migraphx::shape slice_shape{result.get_shape().type(), slice_lens};

    hip_visit_all(result, arg, slice_shape)([&](auto dst, auto src, auto slices) {
        const index_int block_limit       = 256;
        const index_int threads_per_slice = compute_block_size(slice_len, block_limit);

        // One run of `threads_per_slice` consecutive launch indices per slice.
        // NOTE(review): presumably gs_launch takes (global size, local size) - confirm.
        const auto launch_size = slice_shape.elements() * threads_per_slice;

        gs_launch(stream, launch_size, threads_per_slice)([=](auto i, auto idx) __device__ {
            // Multi-index selecting this slice; only its `axis` entry is
            // rewritten by the passes below.
            auto pos = slices.multi(i / threads_per_slice);

            using value_t = device_type<std::remove_cv_t<typename decltype(src)::value_type>>;
            value_t max_seed = lowest();

            // Pass 1: largest element of the slice.
            auto slice_max = block_reduce<block_limit>(
                idx, max{}, max_seed, slice_len, [&](auto j) __device__ {
                    pos[axis] = j;
                    return src[pos];
                });

            // Pass 2: sum of the exponentials of the max-shifted elements.
            auto slice_sum =
                block_reduce<block_limit>(idx, sum{}, 0, slice_len, [&](auto j) __device__ {
                    pos[axis] = j;
                    return ::exp(to_hip_type(src[pos] - slice_max));
                });

            // Pass 3: normalise each element of the slice and store it.
            idx.local_stride(slice_len, [&](auto j) {
                pos[axis] = j;
                dst[pos]  = ::exp(to_hip_type(src[pos] - slice_max)) / slice_sum;
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx