Unverified Commit bfbf0c27 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Skip fusing group convolutions (#531)



* Skip fusing group convolutions

* Formatting

* Fix ICE on gcc 5

* Formatting

* Fix gcc check

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 90200619
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -212,11 +213,17 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m) ...@@ -212,11 +213,17 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return result; return result;
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program /// Find matches for an instruction in the program
template <class... Ms> template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms) void find_matches(program& p, instruction_ref ins, Ms&&... ms)
{ {
bool match = false; #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
...@@ -224,6 +231,11 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms) ...@@ -224,6 +231,11 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
auto r = match_instruction(p, ins, m.matcher()); auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end()) if(r.result == p.end())
return; return;
if(trace)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins);
}
m.apply(p, r); m.apply(p, r);
match = true; match = true;
}, },
......
...@@ -23,8 +23,8 @@ struct convolution ...@@ -23,8 +23,8 @@ struct convolution
std::array<std::size_t, 2> stride = {{1, 1}}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}}; std::array<std::size_t, 2> dilation = {{1, 1}};
padding_mode_t padding_mode = default_;
int group = 1; int group = 1;
padding_mode_t padding_mode = default_;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -32,8 +32,8 @@ struct convolution ...@@ -32,8 +32,8 @@ struct convolution
return pack(f(self.padding, "padding"), return pack(f(self.padding, "padding"),
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.dilation, "dilation"), f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"), f(self.group, "group"),
f(self.group, "group")); f(self.padding_mode, "padding_mode"));
} }
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
...@@ -45,6 +45,9 @@ struct convolution ...@@ -45,6 +45,9 @@ struct convolution
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type(); auto t = input.type();
if(input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
return {t, return {t,
{ {
input.lens()[0], input.lens()[0],
......
...@@ -613,6 +613,13 @@ struct find_conv_dot_horiz_fusion ...@@ -613,6 +613,13 @@ struct find_conv_dot_horiz_fusion
auto&& name = (*start)->name(); auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name)) if(not contains({"dot", "convolution"}, name))
return; return;
auto op = (*start)->get_operator();
int group = 1;
if(name == "convolution")
group = any_cast<op::convolution>(op).group;
// Skip group convolution
if(group != 1)
return;
auto input = (*start)->inputs().front(); auto input = (*start)->inputs().front();
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
std::transform( std::transform(
...@@ -628,9 +635,8 @@ struct find_conv_dot_horiz_fusion ...@@ -628,9 +635,8 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args) for(auto arg : args)
p.move_instructions(arg, input); p.move_instructions(arg, input);
// TODO: Check if axises match // TODO: Check if axises match
auto concat = p.insert_instruction(input, op::concat{concat_axis}, args); auto concat = p.insert_instruction(input, op::concat{concat_axis}, args);
auto fused = auto fused = p.insert_instruction(std::next(input), op, input, concat);
p.insert_instruction(std::next(input), (*start)->get_operator(), input, concat);
int64_t offset = 0; int64_t offset = 0;
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
......
...@@ -1137,7 +1137,28 @@ TEST_CASE(simplify_conv_horiz) ...@@ -1137,7 +1137,28 @@ TEST_CASE(simplify_conv_horiz)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(simplify_conv_horiz_groups) TEST_CASE(simplify_group_conv_horiz)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {1, 32, 111, 111}};
auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto w1 = p1.add_literal(migraphx::generate_literal(ws, 1));
auto w2 = p1.add_literal(migraphx::generate_literal(ws, 2));
auto conv1 =
p1.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w1);
auto conv2 =
p1.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w2);
p1.add_instruction(pass_op{}, conv1, conv2);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_grouped)
{ {
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
...@@ -1184,7 +1205,7 @@ TEST_CASE(simplify_conv_horiz_groups) ...@@ -1184,7 +1205,7 @@ TEST_CASE(simplify_conv_horiz_groups)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(simplify_conv_horiz_groups_extra1) TEST_CASE(simplify_conv_horiz_grouped_extra1)
{ {
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
...@@ -1238,7 +1259,7 @@ TEST_CASE(simplify_conv_horiz_groups_extra1) ...@@ -1238,7 +1259,7 @@ TEST_CASE(simplify_conv_horiz_groups_extra1)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(simplify_conv_horiz_groups_extra2) TEST_CASE(simplify_conv_horiz_grouped_extra2)
{ {
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment