Commit 237c3dbb authored by Paul's avatar Paul
Browse files

Merge branch 'softmax'

parents e2630858 601122f8
......@@ -105,7 +105,7 @@ struct onnx_parser
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto dims = args.front()->get_shape().lens();
auto r = prog.add_instruction(reshape{{long(dims[0]), 1, 1, long(dims[1])}}, args.front());
auto r = prog.add_instruction(reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
auto s = prog.add_instruction(softmax{}, r);
return prog.add_instruction(reshape{{long(dims[0]), long(dims[1])}}, s);
}
......
......@@ -51,6 +51,8 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
auto x = run_cpu(f);
auto y = run_gpu(f);
migraph::verify_args(name, x, y, tolerance);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
void verify_instructions(const migraph::program& prog, double tolerance = 80)
......
......@@ -251,7 +251,7 @@ struct test_softmax2
migraph::program create_program() const
{
migraph::program p;
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000}});
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraph::softmax{}, x);
return p;
}
......@@ -539,8 +539,7 @@ int main()
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_softmax>();
// TODO: Add reshapes to make this a valid test case
// verify_program<test_softmax2>();
verify_program<test_softmax2>();
verify_program<test_conv>();
verify_program<test_conv_relu>();
verify_program<test_add_relu>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment