Commit a9212aad authored by Paul's avatar Paul
Browse files

Use correct order for batchnorm

parent eef9de6f
...@@ -48,10 +48,10 @@ struct cpu_batch_norm_inference ...@@ -48,10 +48,10 @@ struct cpu_batch_norm_inference
double epsilon = op.epsilon; double epsilon = op.epsilon;
auto input = args[0]; auto input = args[0];
auto mini_batch_mean = args[1]; auto arg_gamma = args[1];
auto mini_batch_variance = args[2]; auto arg_bias = args[2];
auto arg_gamma = args[3]; auto mini_batch_mean = args[3];
auto arg_bias = args[4]; auto mini_batch_variance = args[4];
auto num_batch = output_shape.lens()[0]; auto num_batch = output_shape.lens()[0];
auto num_channels = output_shape.lens()[1]; auto num_channels = output_shape.lens()[1];
......
...@@ -48,10 +48,10 @@ struct miopen_batch_norm_inference ...@@ -48,10 +48,10 @@ struct miopen_batch_norm_inference
y_desc.get(), y_desc.get(),
args[5].implicit(), args[5].implicit(),
bn_desc.get(), bn_desc.get(),
args[3].implicit(),
args[4].implicit(),
args[1].implicit(), args[1].implicit(),
args[2].implicit(), args[2].implicit(),
args[3].implicit(),
args[4].implicit(),
op.epsilon); op.epsilon);
return args[5]; return args[5];
......
...@@ -34,7 +34,7 @@ void batch_norm_inference_test() ...@@ -34,7 +34,7 @@ void batch_norm_inference_test()
auto mean = p.add_literal(migraph::literal{vars, mean_data}); auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data}); auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -345,11 +345,11 @@ struct test_batchnorm_inference_2 ...@@ -345,11 +345,11 @@ struct test_batchnorm_inference_2
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0))); auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -368,11 +368,11 @@ struct test_batchnorm_inference ...@@ -368,11 +368,11 @@ struct test_batchnorm_inference
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}}; migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0))); auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
return p; return p;
} }
}; };
...@@ -385,16 +385,16 @@ struct test_conv_bn_relu_pooling ...@@ -385,16 +385,16 @@ struct test_conv_bn_relu_pooling
migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}}; migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}};
migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}}; migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}};
migraph::shape vars{migraph::shape::float_type, {64}};
auto x = p.add_parameter("x", xs); auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws); auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w); auto conv = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
migraph::shape vars{migraph::shape::float_type, {64}}; auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0))); auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1))); auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2))); auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto bn = auto bn =
p.add_instruction(migraph::batch_norm_inference{}, conv, mean, variance, scale, bias); p.add_instruction(migraph::batch_norm_inference{}, conv, scale, bias, mean, variance);
auto relu = p.add_instruction(migraph::activation{"relu"}, bn); auto relu = p.add_instruction(migraph::activation{"relu"}, bn);
p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu); p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment