Commit 1c3f847d authored by Khalique
Browse files

initial testing of group conv

parent 7d0847de
...@@ -71,7 +71,8 @@ struct convolution ...@@ -71,7 +71,8 @@ struct convolution
return pack(f(self.padding, "padding"), return pack(f(self.padding, "padding"),
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.dilation, "dilation"), f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode")); f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
} }
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
...@@ -87,7 +88,7 @@ struct convolution ...@@ -87,7 +88,7 @@ struct convolution
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
weights.lens()[0] * group, weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
...@@ -106,7 +107,7 @@ struct convolution ...@@ -106,7 +107,7 @@ struct convolution
{ {
return {t, return {t,
{input.lens()[0], {input.lens()[0],
weights.lens()[0] * group, weights.lens()[0],
static_cast<std::size_t>( static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])), std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>( static_cast<std::size_t>(
...@@ -117,7 +118,7 @@ struct convolution ...@@ -117,7 +118,7 @@ struct convolution
return { return {
t, t,
{input.lens()[0], {input.lens()[0],
weights.lens()[0] * group, weights.lens()[0],
static_cast<std::size_t>(std::ceil( static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])), static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil( static_cast<std::size_t>(std::ceil(
......
...@@ -112,12 +112,15 @@ struct cpu_convolution ...@@ -112,12 +112,15 @@ struct cpu_convolution
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_h = input.get_shape().lens()[2]; auto in = input.get_shape().lens();
auto in_w = input.get_shape().lens()[3]; auto in_h = in[2];
auto in_w = in[3];
auto wei_c = weights.get_shape().lens()[1]; auto wei = weights.get_shape().lens();
auto wei_h = weights.get_shape().lens()[2]; auto wei_n = wei[0];
auto wei_w = weights.get_shape().lens()[3]; auto wei_c = wei[1];
auto wei_h = wei[2];
auto wei_w = wei[3];
dfor(output_shape.lens()[0], dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
...@@ -126,14 +129,16 @@ struct cpu_convolution ...@@ -126,14 +129,16 @@ struct cpu_convolution
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0]; const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; const int start_y = j * op.stride[1] - op.padding[1];
const int group_id = w / (wei_n / op.group);
double acc = 0; double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) { dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x; const int in_x = start_x + x;
const int in_y = start_y + y; const int in_y = start_y + y;
const int in_ch = group_id * wei_c + k;
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w) if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{ {
acc += input(o, k, in_x, in_y) * weights(w, k, x, y); acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
} }
}); });
output(o, w, i, j) = acc; output(o, w, i, j) = acc;
......
...@@ -137,6 +137,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -137,6 +137,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto wei = ins->inputs().at(1)->get_shape(); auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4); assert(wei.lens().size() == 4);
auto conv = any_cast<miopen_convolution>(ins->get_operator()); auto conv = any_cast<miopen_convolution>(ins->get_operator());
if (conv.op.group > 1)
return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd) if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
return false; return false;
auto op = conv.op; auto op = conv.op;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment