"src/vscode:/vscode.git/clone" did not exist on "e78748337ee0458cc067eb4fb71ee2a4e1bc0407"
Commit eb1e4353 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments. improve the implementation of parsing the ConstantFill operator,...

fix comments. improve the implementation of parsing the ConstantFill operator, and add more tests for better code coverage.
parent 45867925
...@@ -658,7 +658,11 @@ struct gather ...@@ -658,7 +658,11 @@ struct gather
in_idx = out_idx; in_idx = out_idx;
// max dimension in axis // max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis]; std::size_t max_dim = args[0].get_shape().lens()[axis];
std::size_t idx = args[1].at<std::size_t>(out_idx[axis]); std::vector<std::size_t> vec_indices(args[1].get_shape().lens().size());
args[1].visit([&](auto indices) {
vec_indices.assign(indices.begin(), indices.end());
});
std::size_t idx = vec_indices.at(out_idx[axis]);
if(idx >= max_dim) if(idx >= max_dim)
{ {
MIGRAPHX_THROW("Gather, indices are out of range in input tensor"); MIGRAPHX_THROW("Gather, indices are out of range in input tensor");
...@@ -670,8 +674,8 @@ struct gather ...@@ -670,8 +674,8 @@ struct gather
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
std::vector<std::size_t> in_idx; std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, args, in_idx); this->compute_index(idx, args, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end()); output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
}); });
......
...@@ -556,6 +556,7 @@ struct onnx_parser ...@@ -556,6 +556,7 @@ struct onnx_parser
return prog.add_literal(migraphx::literal{s, vec_shape}); return prog.add_literal(migraphx::literal{s, vec_shape});
} }
// Use a literal instruction to replace the constantFill operator. In RNN, input shape // Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill // and value are fixed, so no need to do the actual computation for the constantFill
// operator // operator
...@@ -563,11 +564,6 @@ struct onnx_parser ...@@ -563,11 +564,6 @@ struct onnx_parser
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
if(args.size() != 1)
{
MIGRAPHX_THROW("Constantfill, MIGraphX only handle the case with 1 operand");
}
int input_as_shape = 0; int input_as_shape = 0;
int dtype = 1; int dtype = 1;
float value = 0.0f; float value = 0.0f;
...@@ -588,28 +584,50 @@ struct onnx_parser ...@@ -588,28 +584,50 @@ struct onnx_parser
value = parse_value(attributes.at("value")).at<float>(); value = parse_value(attributes.at("value")).at<float>();
} }
if (contains(attributes, "extra_shape")) {
MIGRAPHX_THROW("ConstantFill, cannot handle extra shape attribute");
}
if(input_as_shape == 1) if(input_as_shape == 1)
{ {
if (args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill, need an input argument as output shape");
}
if (contains(attributes, "shape")) {
MIGRAPHX_THROW("ConstantFill, cannot set the shape argument and pass in an input at the same time");
}
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) if(in.empty())
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
"ConstantFill, cannot handle dynamic shape as input for ConstantFill"); "ConstantFill, cannot handle dynamic shape as input");
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims); migraphx::shape s(type, dims);
return prog.add_literal(migraphx::literal(s, {value})); std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
} }
else if(input_as_shape == 0) else if(input_as_shape == 0)
{ {
std::vector<std::size_t> dims = args[0]->get_shape().lens(); if (!contains(attributes, "shape")) {
MIGRAPHX_THROW("ConstantFill, attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims(ls.get_shape().elements());
ls.visit([&] (auto s) { dims.assign(s.begin(), s.end()); } );
migraphx::shape s{type, dims}; migraphx::shape s{type, dims};
return prog.add_literal(migraphx::literal(s, {value})); std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
} }
else else
{ {
MIGRAPHX_THROW("Wrong input for ConstantFill"); MIGRAPHX_THROW("ConstantFill, wrong value of attribute input_as_shape");
} }
} }
......
...@@ -494,6 +494,33 @@ TEST_CASE(constant_test) ...@@ -494,6 +494,33 @@ TEST_CASE(constant_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(constant_fill_test)
{
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2}}, {2, 3}});
std::vector<std::size_t> dims(l0->get_shape().elements());
migraphx::literal ls = l0->get_literal();
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{migraphx::shape::float_type, dims};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill1.onnx");
EXPECT(p == prog);
}
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> value(s.elements(), 1.0);
p.add_literal(migraphx::literal{s, value});
auto prog = migraphx::parse_onnx("const_fill2.onnx");
EXPECT(p == prog);
}
}
TEST_CASE(gemm_test) TEST_CASE(gemm_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment