Commit 04e01f74 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a comments and add two more tests for gather.

parent 1bdd55e8
...@@ -653,17 +653,13 @@ struct gather ...@@ -653,17 +653,13 @@ struct gather
} }
template <class T> template <class T>
void compute_index(const T& out_idx, const std::vector<argument>& args, T& in_idx) const void compute_index(const T& out_idx, const std::vector<std::size_t>& vec_indices, const std::size_t max_dim, T& in_idx) const
{ {
in_idx = out_idx; in_idx = out_idx;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::vector<std::size_t> vec_indices(args[1].get_shape().lens().size());
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
std::size_t idx = vec_indices.at(out_idx[axis]); std::size_t idx = vec_indices.at(out_idx[axis]);
if(idx >= max_dim) if(idx >= max_dim)
{ {
MIGRAPHX_THROW("Gather, indices are out of range in input tensor"); MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
} }
in_idx[axis] = idx; in_idx[axis] = idx;
} }
...@@ -671,10 +667,14 @@ struct gather ...@@ -671,10 +667,14 @@ struct gather
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx; std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, args, in_idx); this->compute_index(idx, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end()); output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
}); });
}); });
......
...@@ -546,7 +546,7 @@ struct onnx_parser ...@@ -546,7 +546,7 @@ struct onnx_parser
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
if(args.size() != 1) if(args.size() != 1)
MIGRAPHX_THROW("Shape, operator should have 1 operand"); MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens(); std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size()); std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()}); migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
...@@ -585,26 +585,26 @@ struct onnx_parser ...@@ -585,26 +585,26 @@ struct onnx_parser
if(contains(attributes, "extra_shape")) if(contains(attributes, "extra_shape"))
{ {
MIGRAPHX_THROW("ConstantFill, cannot handle extra shape attribute"); MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
} }
if(input_as_shape == 1) if(input_as_shape == 1)
{ {
if(args.size() != 1) if(args.size() != 1)
{ {
MIGRAPHX_THROW("ConstantFill, need an input argument as output shape"); MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
} }
if(contains(attributes, "shape")) if(contains(attributes, "shape"))
{ {
MIGRAPHX_THROW("ConstantFill, cannot set the shape argument and pass in an input " MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time"); "at the same time");
} }
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) if(in.empty())
{ {
MIGRAPHX_THROW("ConstantFill, cannot handle dynamic shape as input"); MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
...@@ -617,11 +617,11 @@ struct onnx_parser ...@@ -617,11 +617,11 @@ struct onnx_parser
{ {
if(!contains(attributes, "shape")) if(!contains(attributes, "shape"))
{ {
MIGRAPHX_THROW("ConstantFill, attribute output shape is needed"); MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
} }
literal ls = parse_value(attributes.at("shape")); literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims(ls.get_shape().elements()); std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); }); ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims}; migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value); std::vector<float> values(s.elements(), value);
...@@ -629,7 +629,7 @@ struct onnx_parser ...@@ -629,7 +629,7 @@ struct onnx_parser
} }
else else
{ {
MIGRAPHX_THROW("ConstantFill, wrong value of attribute input_as_shape"); MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
} }
} }
......
...@@ -212,4 +212,22 @@ TEST_CASE(multibroadcast) ...@@ -212,4 +212,22 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis}, input, indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment