"include/vscode:/vscode.git/clone" did not exist on "25b7bd72bf263e6a4c8a8f23a0e4b57b5b3eec3a"
Unverified Commit 96f7ae5b authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #150 from ROCmSoftwarePlatform/group_conv

Group conv
parents 43222796 66bd7da6
...@@ -63,6 +63,7 @@ struct convolution ...@@ -63,6 +63,7 @@ struct convolution
valid valid
}; };
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -70,7 +71,8 @@ struct convolution ...@@ -70,7 +71,8 @@ struct convolution
return pack(f(self.padding, "padding"), return pack(f(self.padding, "padding"),
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.dilation, "dilation"), f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode")); f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
} }
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
......
...@@ -243,6 +243,10 @@ struct onnx_parser ...@@ -243,6 +243,10 @@ struct onnx_parser
op.padding_mode = op::convolution::same; op.padding_mode = op::convolution::same;
} }
} }
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; uint64_t axis = 1;
......
...@@ -112,28 +112,33 @@ struct cpu_convolution ...@@ -112,28 +112,33 @@ struct cpu_convolution
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) { visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_h = input.get_shape().lens()[2]; auto in = input.get_shape().lens();
auto in_w = input.get_shape().lens()[3]; auto in_h = in[2];
auto in_w = in[3];
auto wei_c = weights.get_shape().lens()[1]; auto wei = weights.get_shape().lens();
auto wei_h = weights.get_shape().lens()[2]; auto wei_n = wei[0];
auto wei_w = weights.get_shape().lens()[3]; auto wei_c = wei[1];
auto wei_h = wei[2];
auto wei_w = wei[3];
dfor(output_shape.lens()[0], dfor(output_shape.lens()[0],
output_shape.lens()[1], output_shape.lens()[1],
output_shape.lens()[2], output_shape.lens()[2],
output_shape.lens()[3])( output_shape.lens()[3])(
[&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) { [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
const int start_x = i * op.stride[0] - op.padding[0]; const int start_x = i * op.stride[0] - op.padding[0];
const int start_y = j * op.stride[1] - op.padding[1]; const int start_y = j * op.stride[1] - op.padding[1];
const int group_id = w / (wei_n / op.group);
double acc = 0; double acc = 0;
dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) { dfor(wei_c, wei_h, wei_w)([&](std::size_t k, std::size_t x, std::size_t y) {
const int in_x = start_x + x; const int in_x = start_x + x;
const int in_y = start_y + y; const int in_y = start_y + y;
const int in_ch = group_id * wei_c + k;
if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w) if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
{ {
acc += input(o, k, in_x, in_y) * weights(w, k, x, y); acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
} }
}); });
output(o, w, i, j) = acc; output(o, w, i, j) = acc;
......
...@@ -137,6 +137,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -137,6 +137,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto wei = ins->inputs().at(1)->get_shape(); auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4); assert(wei.lens().size() == 4);
auto conv = any_cast<miopen_convolution>(ins->get_operator()); auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.op.group > 1)
return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd) if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
return false; return false;
auto op = conv.op; auto op = conv.op;
......
...@@ -54,14 +54,19 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s) ...@@ -54,14 +54,19 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
inline convolution_descriptor make_conv(const migraphx::op::convolution& op) inline convolution_descriptor make_conv(const migraphx::op::convolution& op)
{ {
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor); auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenConvolutionMode_t c_mode = miopenConvolution;
if(op.group > 1)
c_mode = miopenGroupConv;
miopenInitConvolutionDescriptor(c.get(), miopenInitConvolutionDescriptor(c.get(),
miopenConvolution, c_mode,
op.padding[0], op.padding[0],
op.padding[1], op.padding[1],
op.stride[0], op.stride[0],
op.stride[1], op.stride[1],
op.dilation[0], op.dilation[0],
op.dilation[1]); op.dilation[1]);
if(op.group > 1)
miopenSetConvolutionGroupCount(c.get(), op.group);
return c; return c;
} }
......
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include <migraphx/gpu/concat.hpp> #include <migraphx/gpu/concat.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -534,6 +534,22 @@ struct test_conv2 ...@@ -534,6 +534,22 @@ struct test_conv2
} }
}; };
struct test_group_conv
{
migraphx::program create_program() const
{
migraphx::program p;
auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto weights =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, input, weights);
return p;
}
};
struct test_conv_relu struct test_conv_relu
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -1034,6 +1050,7 @@ int main() ...@@ -1034,6 +1050,7 @@ int main()
verify_program<test_softmax2>(); verify_program<test_softmax2>();
verify_program<test_conv>(); verify_program<test_conv>();
verify_program<test_conv2>(); verify_program<test_conv2>();
verify_program<test_group_conv>();
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>(); verify_program<test_conv_relu_half>();
verify_program<test_add_relu>(); verify_program<test_add_relu>();
......
group_conv-example:

0
12"Conv*
grouptest-group_convZ
0




Z
1




b
2




B
\ No newline at end of file
...@@ -484,4 +484,15 @@ TEST_CASE(add_scalar_test) ...@@ -484,4 +484,15 @@ TEST_CASE(add_scalar_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(group_conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, l0, l1);
auto prog = migraphx::parse_onnx("group_conv_test.onnx");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment