Commit a79ab4d7 authored by Khalique's avatar Khalique
Browse files

formatting

parent 5168b178
...@@ -104,25 +104,25 @@ struct cpu_LRN ...@@ -104,25 +104,25 @@ struct cpu_LRN
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
int n_batch = output_shape.lens()[0]; int n_batch = output_shape.lens()[0];
int channels = output_shape.lens()[1]; int channels = output_shape.lens()[1];
int height = output_shape.lens()[2]; int height = output_shape.lens()[2];
int width = output_shape.lens()[3]; int width = output_shape.lens()[3];
float alphaoverarea = op.alpha / op.size; float alphaoverarea = op.alpha / op.size;
int radius = (op.size - 1) / 2; int radius = (op.size - 1) / 2;
dfor(n_batch, height, width)([&](int b, int h, int w) { dfor(n_batch, height, width)([&](int b, int h, int w) {
float scale = 0; float scale = 0;
dfor(channels)([&](int c) { dfor(channels)([&](int c) {
auto start = (c - radius) < 0 ? 0 : (c - radius); auto start = (c - radius) < 0 ? 0 : (c - radius);
auto end = (c + radius) > channels ? channels : (c + radius); auto end = (c + radius) > channels ? channels : (c + radius);
for(auto k = start; k < end; ++k) for(auto k = start; k < end; ++k)
{ {
scale += std::pow(input(b, k, h, w), 2); scale += std::pow(input(b, k, h, w), 2);
} }
scale *= alphaoverarea; scale *= alphaoverarea;
scale += op.bias; scale += op.bias;
scale = std::pow(scale, -op.beta); scale = std::pow(scale, -op.beta);
output(b, c, h, w) = input(b, c, h, w) * scale; output(b, c, h, w) = input(b, c, h, w) * scale;
}); });
}); });
......
...@@ -15,22 +15,22 @@ shape miopen_LRN::compute_shape(const std::vector<shape>& inputs) const ...@@ -15,22 +15,22 @@ shape miopen_LRN::compute_shape(const std::vector<shape>& inputs) const
} }
argument miopen_LRN::compute(context& ctx, argument miopen_LRN::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
miopenLRNForward(ctx.get_stream().get_miopen(), miopenLRNForward(ctx.get_stream().get_miopen(),
ldesc.get(), ldesc.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[0].implicit(), args[0].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[1].implicit(), args[1].implicit(),
false, false,
nullptr); nullptr);
return args[1]; return args[1];
} }
......
...@@ -88,12 +88,7 @@ inline pooling_descriptor make_pooling(const migraph::op::pooling& op) ...@@ -88,12 +88,7 @@ inline pooling_descriptor make_pooling(const migraph::op::pooling& op)
inline LRN_descriptor make_LRN(const migraph::op::LRN& op) inline LRN_descriptor make_LRN(const migraph::op::LRN& op)
{ {
auto ldesc = make_obj<LRN_descriptor>(&miopenCreateLRNDescriptor); auto ldesc = make_obj<LRN_descriptor>(&miopenCreateLRNDescriptor);
miopenSetLRNDescriptor(ldesc.get(), miopenSetLRNDescriptor(ldesc.get(), miopenLRNCrossChannel, op.size, op.alpha, op.beta, op.bias);
miopenLRNCrossChannel,
op.size,
op.alpha,
op.beta,
op.bias);
return ldesc; return ldesc;
} }
......
...@@ -139,10 +139,11 @@ struct miopen_apply ...@@ -139,10 +139,11 @@ struct miopen_apply
instruction_ref apply_LRN(instruction_ref ins) instruction_ref apply_LRN(instruction_ref ins)
{ {
auto&& op = any_cast<op::LRN>(ins->get_operator()); auto&& op = any_cast<op::LRN>(ins->get_operator());
auto ldesc = make_LRN(op); auto ldesc = make_LRN(op);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, miopen_LRN{std::move(ldesc)}, ins->inputs().at(0), output); return prog->replace_instruction(
ins, miopen_LRN{std::move(ldesc)}, ins->inputs().at(0), output);
} }
instruction_ref apply_relu(instruction_ref ins) instruction_ref apply_relu(instruction_ref ins)
......
...@@ -589,11 +589,10 @@ TEST_CASE(LRN_test) ...@@ -589,11 +589,10 @@ TEST_CASE(LRN_test)
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(5); std::vector<float> results_vector(5);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2/1.000075, 1/1.00009, 0/1.000145, 1/1.00009, 2/1.000075}; std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
TEST_CASE(imagescaler_test) TEST_CASE(imagescaler_test)
{ {
migraph::program p; migraph::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment