#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

15
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
16
{
Shucai Xiao's avatar
Shucai Xiao committed
17
    arg_op(argmax_op{}, stream, result, arg, axis);
18
19
20
21
22
23
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx