"vscode:/vscode.git/clone" did not exist on "dbed69058c88ddf42914e6ab3a9b6ea12e15b12a"
Commit f23196d4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

change the axis attribute in gather to int type to follow onnx specification.

parent ecbb4de5
...@@ -635,17 +635,25 @@ struct as_shape ...@@ -635,17 +635,25 @@ struct as_shape
struct gather struct gather
{ {
std::size_t axis = 0; mutable int axis = 0;
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
if(axis >= lens.size()) int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{ {
MIGRAPHX_THROW("Gather, axis is out of range."); MIGRAPHX_THROW("Gather: axis is out of range.");
} }
// negative axis means counting dimensions from back
if (axis < 0)
{
axis += n_dim;
}
auto type = inputs[0].type(); auto type = inputs[0].type();
lens[axis] = inputs[1].elements(); lens[axis] = inputs[1].elements();
......
...@@ -362,7 +362,7 @@ struct onnx_parser ...@@ -362,7 +362,7 @@ struct onnx_parser
instruction_ref instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
std::size_t axis = 0; int axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
......
...@@ -113,7 +113,7 @@ TEST_CASE(gather_test) ...@@ -113,7 +113,7 @@ TEST_CASE(gather_test)
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -133,7 +133,7 @@ TEST_CASE(gather_test) ...@@ -133,7 +133,7 @@ TEST_CASE(gather_test)
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -944,7 +944,7 @@ struct test_gather ...@@ -944,7 +944,7 @@ struct test_gather
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s); auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
......
...@@ -417,7 +417,7 @@ TEST_CASE(gather_test) ...@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1); p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx"); auto prog = migraphx::parse_onnx("gather_test.onnx");
...@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test) ...@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens()); p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}}; migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}}); auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2); p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx"); auto prog = migraphx::parse_onnx("shape_gather.onnx");
......
...@@ -217,7 +217,7 @@ TEST_CASE(gather) ...@@ -217,7 +217,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1; int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
...@@ -227,7 +227,7 @@ TEST_CASE(gather) ...@@ -227,7 +227,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4; int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices); throws_shape(migraphx::op::gather{axis}, input, indices);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment