argmax.cpp 816 Bytes
Newer Older
1
2
3
4
5
6
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
7
#include <migraphx/gpu/device/arg_op.hpp>
8
9
10
11
12
13
14
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

Shucai Xiao's avatar
Shucai Xiao committed
15
/// Device entry point for argmax.
///
/// Visits `arg` to recover its element type, then forwards to the generic
/// `arg_op` driver with a `pair_max` functor over (element, int64_t) pairs.
/// NOTE(review): `result` presumably receives the argmax indices taken along
/// `axis`; the exact output layout is decided by `arg_op`, so confirm there.
///
/// @param stream  HIP stream the work is issued on (forwarded to `arg_op`).
/// @param result  Output argument (forwarded to `arg_op`).
/// @param arg     Input argument; its element type selects the instantiation.
/// @param axis    Reduction axis (forwarded to `arg_op`).
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg.visit([&](auto input) {
        // Strip cv-qualifiers from the visited view's element type before
        // handing it to device_type<>.
        using element_t = std::remove_cv_t<typename decltype(input)::value_type>;
        using device_t  = device_type<element_t>;
        // The functor type doubles as arg_op's explicit template argument.
        using reducer_t = pair_max<device_t, int64_t>;
        arg_op<reducer_t>(reducer_t{}, stream, result, arg, axis);
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx