Commit eb1e4353 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments. improve the implementation of parsing the ConstantFill operator,...

fix comments. improve the implementation of parsing the ConstantFill operator, and add more tests for better code coverage.
parent 45867925
......@@ -658,7 +658,11 @@ struct gather
in_idx = out_idx;
// max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis];
std::size_t idx = args[1].at<std::size_t>(out_idx[axis]);
std::vector<std::size_t> vec_indices(args[1].get_shape().lens().size());
args[1].visit([&](auto indices) {
vec_indices.assign(indices.begin(), indices.end());
});
std::size_t idx = vec_indices.at(out_idx[axis]);
if(idx >= max_dim)
{
MIGRAPHX_THROW("Gather, indices are out of range in input tensor");
......@@ -670,8 +674,8 @@ struct gather
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
std::vector<std::size_t> in_idx;
this->compute_index(idx, args, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
});
......
......@@ -556,6 +556,7 @@ struct onnx_parser
return prog.add_literal(migraphx::literal{s, vec_shape});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
......@@ -563,11 +564,6 @@ struct onnx_parser
attribute_map attributes,
std::vector<instruction_ref> args)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("Constantfill, MIGraphX only handle the case with 1 operand");
}
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
......@@ -588,28 +584,50 @@ struct onnx_parser
value = parse_value(attributes.at("value")).at<float>();
}
if (contains(attributes, "extra_shape")) {
MIGRAPHX_THROW("ConstantFill, cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if (args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill, need an input argument as output shape");
}
if (contains(attributes, "shape")) {
MIGRAPHX_THROW("ConstantFill, cannot set the shape argument and pass in an input at the same time");
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW(
"ConstantFill, cannot handle dynamic shape as input for ConstantFill");
"ConstantFill, cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
return prog.add_literal(migraphx::literal(s, {value}));
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
std::vector<std::size_t> dims = args[0]->get_shape().lens();
if (!contains(attributes, "shape")) {
MIGRAPHX_THROW("ConstantFill, attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims(ls.get_shape().elements());
ls.visit([&] (auto s) { dims.assign(s.begin(), s.end()); } );
migraphx::shape s{type, dims};
return prog.add_literal(migraphx::literal(s, {value}));
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("Wrong input for ConstantFill");
MIGRAPHX_THROW("ConstantFill, wrong value of attribute input_as_shape");
}
}
......
......@@ -494,6 +494,33 @@ TEST_CASE(constant_test)
EXPECT(p == prog);
}
TEST_CASE(constant_fill_test)
{
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
std::vector<std::size_t> dims(l0->get_shape().elements());
migraphx::literal ls = l0->get_literal();
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill1.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill2.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gemm_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment