"src/targets/vscode:/vscode.git/clone" did not exist on "d9a5acbddc03cf9c4a5b01869fc3c51f463d36a1"
Commit e69e915b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change axis attribute from int to int64_t

parent ee46bc9f
......@@ -12,7 +12,7 @@ namespace op {
struct argmax
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -26,7 +26,7 @@ struct argmax
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
//#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
//#include <migraphx/stringutils.hpp>
//#include <migraphx/literal.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
//#include <cmath>
//#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,7 +12,7 @@ namespace op {
struct argmin
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -31,7 +26,7 @@ struct argmin
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
......
......@@ -273,10 +273,10 @@ struct onnx_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
......@@ -288,7 +288,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
......@@ -300,10 +300,10 @@ struct onnx_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
......@@ -315,7 +315,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
......
......@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmax_op{}, stream, result, arg, axis);
}
......
......@@ -12,7 +12,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmin_op{}, stream, result, arg, axis);
}
......
......@@ -70,7 +70,7 @@ struct argmin_op
};
template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens();
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis);
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
......
......@@ -10,7 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis);
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment