Commit 6cda085f authored by Shucai Xiao

Fix parse_softmax and parse_reshape to support the BERT ONNX file

parent cf5aede1
......@@ -260,14 +260,27 @@ struct onnx_parser
         return prog.add_instruction(op, std::move(args));
     }
-    instruction_ref
-    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    // instruction_ref
+    // parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    // {
+    //     auto dims = args.front()->get_shape().lens();
+    //     auto r =
+    //         prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
+    //     auto s = prog.add_instruction(op::softmax{}, r);
+    //     return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
+    // }
+    instruction_ref parse_softmax(const std::string&,
+                                  const attribute_map& attributes,
+                                  std::vector<instruction_ref> args)
     {
-        auto dims = args.front()->get_shape().lens();
-        auto r =
-            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
-        auto s = prog.add_instruction(op::softmax{}, r);
-        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
+        int axis = 1;
+        if(contains(attributes, "axis"))
+        {
+            axis = parse_value(attributes.at("axis")).at<int>();
+        }
+        return prog.add_instruction(op::softmax{axis}, std::move(args));
     }
     instruction_ref parse_logsoftmax(const std::string&,
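
Note on the parse_softmax change above: instead of reshaping a 2-D input to 4-D, applying softmax, and reshaping back, the parser now reads the ONNX axis attribute (defaulting to 1 when absent) and emits a single softmax. A minimal sketch of the before/after program shapes, written against the same test-style API used in the test hunk further down; the headers are assumed to be the usual migraphx/program.hpp and migraphx/operators.hpp:

    #include <migraphx/program.hpp>
    #include <migraphx/operators.hpp>

    // Old behavior: reshape to 4-D, softmax, reshape back to 2-D.
    migraphx::program old_style()
    {
        migraphx::program p;
        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3}});
        auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, x);
        auto s = p.add_instruction(migraphx::op::softmax{}, r);
        p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
        return p;
    }

    // New behavior: a single softmax carrying the axis from the ONNX attribute.
    migraphx::program new_style()
    {
        migraphx::program p;
        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3}});
        p.add_instruction(migraphx::op::softmax{1}, x);
        return p;
    }
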
......@@ -460,8 +473,15 @@ struct onnx_parser
         op::reshape op;
         if(args.size() == 1)
         {
-            literal s = parse_value(attributes.at("shape"));
-            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
+            if(contains(attributes, "shape"))
+            {
+                literal s = parse_value(attributes.at("shape"));
+                s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
+            }
+            else
+            {
+                MIGRAPHX_THROW("Parse_reshape: shape attribute is needed when only one argument is provided!");
+            }
         }
         if(args.size() == 2)
         {
......@@ -470,6 +490,12 @@ struct onnx_parser
                 MIGRAPHX_THROW("Dynamic shape is not supported.");
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         }
+        if(!args[0]->get_shape().standard())
+        {
+            args[0] = prog.add_instruction(op::contiguous{}, args[0]);
+        }
         return prog.add_instruction(op, args[0]);
     }
......
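
Note on the contiguous insertion in parse_reshape: reshape only reinterprets dimensions over a packed, standard-layout buffer, so a non-standard input (for example, the output of a transpose) is first copied with op::contiguous. A rough sketch of the pattern; the transpose operator and the 3-D shapes here are illustrative assumptions rather than part of this commit:

    #include <migraphx/program.hpp>
    #include <migraphx/operators.hpp>

    // Sketch: a transposed tensor has a non-standard layout, so it is copied
    // into a packed buffer before its dimensions are reinterpreted.
    migraphx::program reshape_non_standard()
    {
        migraphx::program p;
        auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
        auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, x); // shape {2, 4, 3}, non-standard
        auto c = p.add_instruction(migraphx::op::contiguous{}, t);         // packed copy
        p.add_instruction(migraphx::op::reshape{{2, 12}}, c);              // now safe to reshape
        return p;
    }

This mirrors what the parser does: the extra copy is inserted only when get_shape().standard() is false, so already-packed inputs are passed through unchanged.
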
......@@ -413,9 +413,7 @@ TEST_CASE(softmax_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
-    auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, l0);
-    auto s = p.add_instruction(migraphx::op::softmax{}, r);
-    p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
+    p.add_instruction(migraphx::op::softmax{1}, l0);
     auto prog = migraphx::parse_onnx("softmax_test.onnx");
     EXPECT(p == prog);
......
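
To see what the parser now produces for a given model, the parsed program can be printed and compared against a hand-built expectation, as the updated softmax_test does. A small standalone sketch, assuming the softmax_test.onnx file from the test data is available in the working directory:

    #include <iostream>
    #include <migraphx/program.hpp>
    #include <migraphx/onnx.hpp>

    int main()
    {
        // Parse the ONNX file and dump the resulting MIGraphX program.
        auto prog = migraphx::parse_onnx("softmax_test.onnx");
        std::cout << prog << std::endl;
        return 0;
    }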