Commit e69e915b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change axis attribute from int to int64_t

parent ee46bc9f
...@@ -12,7 +12,7 @@ namespace op { ...@@ -12,7 +12,7 @@ namespace op {
struct argmax struct argmax
{ {
int axis = 0; int64_t axis = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -26,7 +26,7 @@ struct argmax ...@@ -26,7 +26,7 @@ struct argmax
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0) if(axis >= n_dim || axis < 0)
{ {
MIGRAPHX_THROW("ARGMAX: axis is out of range."); MIGRAPHX_THROW("ARGMAX: axis is out of range.");
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP #define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
//#include <array>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
//#include <migraphx/stringutils.hpp>
//#include <migraphx/literal.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
//#include <cmath>
//#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -17,7 +12,7 @@ namespace op { ...@@ -17,7 +12,7 @@ namespace op {
struct argmin struct argmin
{ {
int axis = 0; int64_t axis = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -31,7 +26,7 @@ struct argmin ...@@ -31,7 +26,7 @@ struct argmin
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0) if(axis >= n_dim || axis < 0)
{ {
MIGRAPHX_THROW("ARGMIN: axis is out of range."); MIGRAPHX_THROW("ARGMIN: axis is out of range.");
......
...@@ -273,10 +273,10 @@ struct onnx_parser ...@@ -273,10 +273,10 @@ struct onnx_parser
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
int axis = 0; int64_t axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
} }
int keep_dims = 1; int keep_dims = 1;
...@@ -288,7 +288,7 @@ struct onnx_parser ...@@ -288,7 +288,7 @@ struct onnx_parser
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args)); auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins); return prog.add_instruction(op::squeeze{{axis}}, ins);
} }
else else
{ {
...@@ -300,10 +300,10 @@ struct onnx_parser ...@@ -300,10 +300,10 @@ struct onnx_parser
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
int axis = 0; int64_t axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
} }
int keep_dims = 1; int keep_dims = 1;
...@@ -315,7 +315,7 @@ struct onnx_parser ...@@ -315,7 +315,7 @@ struct onnx_parser
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args)); auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins); return prog.add_instruction(op::squeeze{{axis}}, ins);
} }
else else
{ {
......
...@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis) void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
arg_op(argmax_op{}, stream, result, arg, axis); arg_op(argmax_op{}, stream, result, arg, axis);
} }
......
...@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis) void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
arg_op(argmin_op{}, stream, result, arg, axis); arg_op(argmin_op{}, stream, result, arg, axis);
} }
......
...@@ -70,7 +70,7 @@ struct argmin_op ...@@ -70,7 +70,7 @@ struct argmin_op
}; };
template <class Op> template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis) void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
auto arg_shape = arg.get_shape(); auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens(); auto lens = arg_shape.lens();
......
...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis); void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis); void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment