Commit e17fb984 authored by Khalique's avatar Khalique
Browse files

initial progress on group conv

parent 84e7335e
......@@ -63,6 +63,7 @@ struct convolution
valid
};
padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -86,7 +87,7 @@ struct convolution
return {t,
{
input.lens()[0],
weights.lens()[0],
weights.lens()[0] * group,
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
......@@ -105,7 +106,7 @@ struct convolution
{
return {t,
{input.lens()[0],
weights.lens()[0],
weights.lens()[0] * group,
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(
......@@ -116,7 +117,7 @@ struct convolution
return {
t,
{input.lens()[0],
weights.lens()[0],
weights.lens()[0] * group,
static_cast<std::size_t>(std::ceil(
static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
static_cast<std::size_t>(std::ceil(
......
......@@ -220,6 +220,10 @@ struct onnx_parser
op.padding_mode = op::convolution::same;
}
}
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
if(args.size() == 3)
{
uint64_t axis = 1;
......
......@@ -54,14 +54,19 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
inline convolution_descriptor make_conv(const migraphx::op::convolution& op)
{
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenConvolutionMode_t c_mode = miopenConvolution;
if (op.group > 1)
c_mode = miopenGroupConv;
miopenInitConvolutionDescriptor(c.get(),
miopenConvolution,
c_mode,
op.padding[0],
op.padding[1],
op.stride[0],
op.stride[1],
op.dilation[0],
op.dilation[1]);
if (op.group > 1)
miopenSetConvolutionGroupCount(c.get(), op.group);
return c;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment