Commit c5353ce2 authored by Shucai Xiao
Browse files

Change the TF parser's softmax implementation to use the softmax op's axis attribute directly, removing the reshape-to-4D workaround

parent e33f42ed
...@@ -704,13 +704,15 @@ struct tf_parser ...@@ -704,13 +704,15 @@ struct tf_parser
} }
instruction_ref instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_softmax(const std::string&, const attribute_map& attributes, std::vector<instruction_ref> args)
{ {
auto dims = args.front()->get_shape().lens(); int axis = 1;
auto r = if(contains(attributes, "axis"))
prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front()); {
auto s = prog.add_instruction(op::softmax{}, r); axis = static_cast<int>(attributes.at("axis").i());
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s); }
return prog.add_instruction(op::softmax{axis}, std::move(args));
} }
instruction_ref parse_squeeze(const std::string&, instruction_ref parse_squeeze(const std::string&,
......
...@@ -355,10 +355,7 @@ TEST_CASE(softmax_test) ...@@ -355,10 +355,7 @@ TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
auto dims = l0->get_shape().lens(); p.add_instruction(migraphx::op::softmax{1}, l0);
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
auto prog = optimize_tf("softmax_test.pb", false); auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment