Commit 04e01f74 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a comments and add two more tests for gather.

parent 1bdd55e8
......@@ -653,17 +653,13 @@ struct gather
}
template <class T>
void compute_index(const T& out_idx, const std::vector<argument>& args, T& in_idx) const
void compute_index(const T& out_idx, const std::vector<std::size_t>& vec_indices, const std::size_t max_dim, T& in_idx) const
{
in_idx = out_idx;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::vector<std::size_t> vec_indices(args[1].get_shape().lens().size());
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
std::size_t idx = vec_indices.at(out_idx[axis]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather, indices are out of range in input tensor");
MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
}
in_idx[axis] = idx;
}
......@@ -671,10 +667,14 @@ struct gather
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, args, in_idx);
this->compute_index(idx, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
});
});
......
......@@ -546,7 +546,7 @@ struct onnx_parser
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape, operator should have 1 operand");
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
......@@ -585,26 +585,26 @@ struct onnx_parser
if(contains(attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill, cannot handle extra shape attribute");
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill, need an input argument as output shape");
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill, cannot set the shape argument and pass in an input "
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill, cannot handle dynamic shape as input");
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
......@@ -617,11 +617,11 @@ struct onnx_parser
{
if(!contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill, attribute output shape is needed");
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims(ls.get_shape().elements());
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
......@@ -629,7 +629,7 @@ struct onnx_parser
}
else
{
MIGRAPHX_THROW("ConstantFill, wrong value of attribute input_as_shape");
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
......
......@@ -212,4 +212,22 @@ TEST_CASE(multibroadcast)
}
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis}, input, indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment