Commit bf336b27 authored by wsttiger's avatar wsttiger
Browse files

Formatting

parent cb7db7a9
......@@ -99,10 +99,7 @@ struct instruction
});
}
shape get_shape() const
{
return result;
}
shape get_shape() const { return result; }
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
......
......@@ -386,19 +386,19 @@ struct miopen_apply
// Not sure how to write this. Review and fix required
void apply_batch_norm_inference(instruction_ref ins)
{
auto&& op = any_cast<batch_norm_inference>(ins->op);
auto output = insert_allocation(ins, ins->result);
auto&& op = any_cast<batch_norm_inference>(ins->op);
auto output = insert_allocation(ins, ins->result);
shape old_shape = ins->arguments.at(1)->get_shape();
std::vector<int64_t> new_shape{1,static_cast<int64_t>(old_shape.elements()),1,1};
auto arg1 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(1));
auto arg2 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(2));
auto arg3 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(3));
auto arg4 = prog->insert_instruction(ins, migraph::reshape{new_shape},
ins->arguments.at(4));
prog->replace_instruction(ins,
std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
auto arg1 =
prog->insert_instruction(ins, migraph::reshape{new_shape}, ins->arguments.at(1));
auto arg2 =
prog->insert_instruction(ins, migraph::reshape{new_shape}, ins->arguments.at(2));
auto arg3 =
prog->insert_instruction(ins, migraph::reshape{new_shape}, ins->arguments.at(3));
auto arg4 =
prog->insert_instruction(ins, migraph::reshape{new_shape}, ins->arguments.at(4));
prog->replace_instruction(ins,
miopen_batch_norm_inference{op},
ins->arguments.at(0),
arg1,
......
......@@ -227,7 +227,7 @@ struct test_batchnorm_inference
migraph::program p;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {1,channels,1,1}};
migraph::shape vars{migraph::shape::float_type, {1, channels, 1, 1}};
auto x = p.add_parameter("x", s);
auto mean = p.add_parameter("mean", vars);
auto variance = p.add_parameter("variance", vars);
......@@ -260,7 +260,7 @@ void batch_norm_inference_test()
const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {1,channels,1,1}};
migraph::shape vars{migraph::shape::float_type, {1, channels, 1, 1}};
std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment