Commit e17fb984 authored by Khalique's avatar Khalique
Browse files

initial progress on group conv

parent 84e7335e
...@@ -63,6 +63,7 @@ struct convolution ...@@ -63,6 +63,7 @@ struct convolution
valid valid
}; };
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -86,7 +87,7 @@ struct convolution ...@@ -86,7 +87,7 @@ struct convolution
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
weights.lens()[0], weights.lens()[0] * group,
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
...@@ -105,7 +106,7 @@ struct convolution ...@@ -105,7 +106,7 @@ struct convolution
{ {
return {t, return {t,
{input.lens()[0], {input.lens()[0],
weights.lens()[0], weights.lens()[0] * group,
static_cast<std::size_t>( static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])), std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>( static_cast<std::size_t>(
...@@ -116,7 +117,7 @@ struct convolution ...@@ -116,7 +117,7 @@ struct convolution
return { return {
t, t,
{input.lens()[0], {input.lens()[0],
weights.lens()[0], weights.lens()[0] * group,
static_cast<std::size_t>(std::ceil( static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])), static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil( static_cast<std::size_t>(std::ceil(
......
...@@ -220,6 +220,10 @@ struct onnx_parser ...@@ -220,6 +220,10 @@ struct onnx_parser
op.padding_mode = op::convolution::same; op.padding_mode = op::convolution::same;
} }
} }
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; uint64_t axis = 1;
......
...@@ -54,14 +54,19 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s) ...@@ -54,14 +54,19 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
inline convolution_descriptor make_conv(const migraphx::op::convolution& op) inline convolution_descriptor make_conv(const migraphx::op::convolution& op)
{ {
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor); auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenConvolutionMode_t c_mode = miopenConvolution;
if (op.group > 1)
c_mode = miopenGroupConv;
miopenInitConvolutionDescriptor(c.get(), miopenInitConvolutionDescriptor(c.get(),
miopenConvolution, c_mode,
op.padding[0], op.padding[0],
op.padding[1], op.padding[1],
op.stride[0], op.stride[0],
op.stride[1], op.stride[1],
op.dilation[0], op.dilation[0],
op.dilation[1]); op.dilation[1]);
if (op.group > 1)
miopenSetConvolutionGroupCount(c.get(), op.group);
return c; return c;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment