Commit 601122f8 authored by Paul's avatar Paul
Browse files

Fix softmax in onnx file

parent 0418f719
...@@ -105,7 +105,7 @@ struct onnx_parser ...@@ -105,7 +105,7 @@ struct onnx_parser
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
auto dims = args.front()->get_shape().lens(); auto dims = args.front()->get_shape().lens();
auto r = prog.add_instruction(reshape{{long(dims[0]), 1, 1, long(dims[1])}}, args.front()); auto r = prog.add_instruction(reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
auto s = prog.add_instruction(softmax{}, r); auto s = prog.add_instruction(softmax{}, r);
return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s); return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s);
} }
......
...@@ -51,6 +51,8 @@ void verify_program(const std::string& name, F f, double tolerance = 100) ...@@ -51,6 +51,8 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
auto x = run_cpu(f); auto x = run_cpu(f);
auto y = run_gpu(f); auto y = run_gpu(f);
migraph::verify_args(name, x, y, tolerance); migraph::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
} }
void verify_instructions(const migraph::program& prog, double tolerance = 80) void verify_instructions(const migraph::program& prog, double tolerance = 80)
......
...@@ -251,7 +251,7 @@ struct test_softmax2 ...@@ -251,7 +251,7 @@ struct test_softmax2
migraph::program create_program() const migraph::program create_program() const
{ {
migraph::program p; migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000}}); auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraph::softmax{}, x); p.add_instruction(migraph::softmax{}, x);
return p; return p;
} }
...@@ -539,8 +539,7 @@ int main() ...@@ -539,8 +539,7 @@ int main()
verify_program<test_add_broadcast4>(); verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>(); verify_program<test_add_broadcast5>();
verify_program<test_softmax>(); verify_program<test_softmax>();
// TODO: Add reshapes to make this a valid test case verify_program<test_softmax2>();
// verify_program<test_softmax2>();
verify_program<test_conv>(); verify_program<test_conv>();
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_add_relu>(); verify_program<test_add_relu>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment