Commit f23196d4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the axis attribute in gather to int type to follow onnx specification.

parent ecbb4de5
...@@ -635,17 +635,25 @@ struct as_shape ...@@ -635,17 +635,25 @@ struct as_shape
struct gather struct gather
{ {
std::size_t axis = 0; mutable int axis = 0;
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
if(axis >= lens.size()) int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{ {
MIGRAPHX_THROW("Gather, axis is out of range."); MIGRAPHX_THROW("Gather: axis is out of range.");
} }
// negative axis means counting dimensions from back
if (axis < 0)
{
axis += n_dim;
}
auto type = inputs[0].type(); auto type = inputs[0].type();
lens[axis] = inputs[1].elements(); lens[axis] = inputs[1].elements();
......
...@@ -362,7 +362,7 @@ struct onnx_parser ...@@ -362,7 +362,7 @@ struct onnx_parser
instruction_ref instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
std::size_t axis = 0; int axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
......
...@@ -113,7 +113,7 @@ TEST_CASE(gather_test) ...@@ -113,7 +113,7 @@ TEST_CASE(gather_test)
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -133,7 +133,7 @@ TEST_CASE(gather_test) ...@@ -133,7 +133,7 @@ TEST_CASE(gather_test)
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -944,7 +944,7 @@ struct test_gather ...@@ -944,7 +944,7 @@ struct test_gather
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s); auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
......
...@@ -417,7 +417,7 @@ TEST_CASE(gather_test) ...@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1); p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx"); auto prog = migraphx::parse_onnx("gather_test.onnx");
...@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test) ...@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens()); p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}}; migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}}); auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2); p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx"); auto prog = migraphx::parse_onnx("shape_gather.onnx");
......
...@@ -217,7 +217,7 @@ TEST_CASE(gather) ...@@ -217,7 +217,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1; int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
...@@ -227,7 +227,7 @@ TEST_CASE(gather) ...@@ -227,7 +227,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4; int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices); throws_shape(migraphx::op::gather{axis}, input, indices);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment