Commit c8017873 authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent d3987641
...@@ -105,7 +105,8 @@ struct onnx_parser ...@@ -105,7 +105,8 @@ struct onnx_parser
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
auto dims = args.front()->get_shape().lens(); auto dims = args.front()->get_shape().lens();
auto r = prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front()); auto r =
prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
auto s = prog.add_instruction(op::softmax{}, r); auto s = prog.add_instruction(op::softmax{}, r);
return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s); return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
} }
...@@ -231,10 +232,10 @@ struct onnx_parser ...@@ -231,10 +232,10 @@ struct onnx_parser
instruction_ref instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
float epsilon = 1e-5f; float epsilon = 1e-5f;
float momentum = 0.9f; float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial; op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
bool is_test = false; bool is_test = false;
if(contains(attributes, "epsilon")) if(contains(attributes, "epsilon"))
{ {
epsilon = parse_value(attributes.at("epsilon")).at<float>(); epsilon = parse_value(attributes.at("epsilon")).at<float>();
......
...@@ -72,7 +72,7 @@ void pytorch_conv_relu_maxpool_x2() ...@@ -72,7 +72,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2); auto l4 = p.add_instruction(migraph::op::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4); auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5); auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6); auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}}); auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}}); auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
......
...@@ -117,16 +117,18 @@ void reshape_shape() ...@@ -117,16 +117,18 @@ void reshape_shape()
void flatten_shape() void flatten_shape()
{ {
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}}; migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
expect_shape( expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, migraph::op::flatten{0}, input); migraph::op::flatten{0},
input);
expect_shape( expect_shape(
migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input); migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input);
expect_shape( expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input); migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
expect_shape( expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input); migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
expect_shape( expect_shape(migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}}, migraph::op::flatten{4}, input); migraph::op::flatten{4},
input);
throws_shape(migraph::op::flatten{5}, input); throws_shape(migraph::op::flatten{5}, input);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment