// argmax.cpp
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Author (per repository blame): Shucai Xiao
// Computes, for every slice of `arg` taken along `axis`, the position (along that
// axis) of the largest element, and writes those positions into `result`.
//
// Launch layout: one work-group of `block_size` threads per output element ("batch").
// Each work-group reduces the lens[axis] elements of its slice in two phases:
//   1. every thread sequentially scans the elements i = local, local + block_size, ...
//      keeping a private running maximum;
//   2. the per-thread winners are combined with a tree reduction in LDS.
// On ties the smallest index wins (first occurrence).
//
// stream : HIP stream the kernel is enqueued on; this function does not synchronize.
// result : output buffer of int64 indices, one per batch element, indexed by the
//          linear batch index -- assumes a packed int64 tensor with
//          product(lens) / lens[axis] elements (TODO confirm against the argmax op's
//          compute_shape).
// arg    : input tensor.
// axis   : axis to reduce; negative values count from the last axis.
// returns: `result`.
argument argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto lens = arg.get_shape().lens();
    if(axis < 0)
        axis += static_cast<int>(lens.size());

    auto batch_lens       = lens;
    size_t batch_item_num = lens[axis]; // number of elements reduced per batch
    batch_lens[axis]      = 1;
    // Only the lens (and the strides derived from them) of batch_shape are used.
    migraphx::shape batch_shape{shape::float_type, batch_lens};

    // Nothing to reduce: a zero-sized grid cannot be launched, and the reduction
    // below relies on at least one element per batch.
    if(batch_item_num == 0 or batch_shape.elements() == 0)
        return result;

    // The output holds indices, so its element type differs from the input's;
    // visiting both together (visit_all) would require identical types.
    auto* output_ptr = device_cast(result.get<int64_t>().data());

    // each work-group is responsible for one batch
    constexpr size_t block_size = 1024;
    // number of LDS slots that receive a per-thread winner in phase 1
    const size_t active = (batch_item_num < block_size) ? batch_item_num : block_size;

    arg.visit([&](auto input) {
        const auto* input_ptr = device_cast(input.data());
        // comparisons are done in the INPUT element type
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;

        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
            hip_tensor_descriptor<n_dim> desc_data(arg.get_shape());

            launch(
                stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
                const size_t thr_idx = idx.local;
                const size_t blk_idx = idx.group;

                MIGRAPHX_DEVICE_SHARED type lds_data[block_size];
                MIGRAPHX_DEVICE_SHARED int64_t lds_index[block_size];

                // Phase 1: private sequential scan of this thread's strided subset.
                // Indices are visited in increasing order and only a strictly larger
                // value replaces the current best, so the earliest index wins ties.
                auto data_idx = desc_batch.multi(blk_idx);
                if(thr_idx < batch_item_num)
                {
                    data_idx[axis]   = thr_idx;
                    type best_val    = input_ptr[desc_data.linear(data_idx)];
                    int64_t best_idx = static_cast<int64_t>(thr_idx);
                    for(size_t i = thr_idx + block_size; i < batch_item_num; i += block_size)
                    {
                        data_idx[axis] = i;
                        type val       = input_ptr[desc_data.linear(data_idx)];
                        if(best_val < val)
                        {
                            best_val = val;
                            best_idx = static_cast<int64_t>(i);
                        }
                    }
                    lds_data[thr_idx]  = best_val;
                    lds_index[thr_idx] = best_idx;
                }
                // Reached by every thread of the work-group (outside the branch above).
                __syncthreads();

                // Phase 2: tree reduction over slots [0, active). The loop bounds depend
                // only on `active`, so every thread executes the same number of
                // iterations and each __syncthreads is reached uniformly. In each step,
                // writers touch slots < stride and other threads only read slots >= stride,
                // so there is no intra-step race.
                size_t size = active;
                while(size > 1)
                {
                    const size_t stride = (size + 1) / 2;
                    if(thr_idx + stride < size)
                    {
                        const type other_val    = lds_data[thr_idx + stride];
                        const int64_t other_idx = lds_index[thr_idx + stride];
                        // Slots are not ordered by original index after phase 1, so
                        // break ties explicitly in favour of the smaller index.
                        if(lds_data[thr_idx] < other_val or
                           (lds_data[thr_idx] == other_val and other_idx < lds_index[thr_idx]))
                        {
                            lds_data[thr_idx]  = other_val;
                            lds_index[thr_idx] = other_idx;
                        }
                    }
                    __syncthreads();
                    size = stride;
                }

                if(thr_idx == 0)
                {
                    output_ptr[blk_idx] = lds_index[0];
                }
            });
        });
    });

    return result;
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx