Commit a9212aad authored by Paul's avatar Paul
Browse files

Use correct order for batchnorm

parent eef9de6f
......@@ -48,10 +48,10 @@ struct cpu_batch_norm_inference
double epsilon = op.epsilon;
auto input = args[0];
auto mini_batch_mean = args[1];
auto mini_batch_variance = args[2];
auto arg_gamma = args[3];
auto arg_bias = args[4];
auto arg_gamma = args[1];
auto arg_bias = args[2];
auto mini_batch_mean = args[3];
auto mini_batch_variance = args[4];
auto num_batch = output_shape.lens()[0];
auto num_channels = output_shape.lens()[1];
......
......@@ -48,10 +48,10 @@ struct miopen_batch_norm_inference
y_desc.get(),
args[5].implicit(),
bn_desc.get(),
args[3].implicit(),
args[4].implicit(),
args[1].implicit(),
args[2].implicit(),
args[3].implicit(),
args[4].implicit(),
op.epsilon);
return args[5];
......
......@@ -34,7 +34,7 @@ void batch_norm_inference_test()
auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
......
......@@ -345,11 +345,11 @@ struct test_batchnorm_inference_2
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
return p;
}
};
......@@ -368,11 +368,11 @@ struct test_batchnorm_inference
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
auto x = p.add_parameter("x", s);
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
return p;
}
};
......@@ -385,16 +385,16 @@ struct test_conv_bn_relu_pooling
migraph::shape xs{migraph::shape::float_type, {1, 3, 224, 224}};
migraph::shape ws{migraph::shape::float_type, {64, 3, 7, 7}};
migraph::shape vars{migraph::shape::float_type, {64}};
auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, x, w);
migraph::shape vars{migraph::shape::float_type, {64}};
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 0)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
auto bn =
p.add_instruction(migraph::batch_norm_inference{}, conv, mean, variance, scale, bias);
p.add_instruction(migraph::batch_norm_inference{}, conv, scale, bias, mean, variance);
auto relu = p.add_instruction(migraph::activation{"relu"}, bn);
p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment