Commit 8d258ac4 authored by Khalique's avatar Khalique
Browse files

formatting

parent 1c3f847d
......@@ -112,11 +112,11 @@ struct cpu_convolution
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in = input.get_shape().lens();
auto in = input.get_shape().lens();
auto in_h = in[2];
auto in_w = in[3];
auto wei = weights.get_shape().lens();
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto wei_h = wei[2];
......@@ -127,14 +127,14 @@ struct cpu_convolution
output_shape.lens()[2],
output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1];
const int group_id = w / (wei_n / op.group);
double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x;
const int in_y = start_y + y;
const int in_x = start_x + x;
const int in_y = start_y + y;
const int in_ch = group_id * wei_c + k;
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{
......
......@@ -137,7 +137,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4);
auto conv = any_cast<miopen_convolution>(ins->get_operator());
if (conv.op.group > 1)
if(conv.op.group > 1)
return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment