Unverified Commit bfbf0c27 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Skip fusing group convolutions (#531)



* Skip fusing group convolutions

* Formatting

* Fix ICE on gcc 5

* Formatting

* Fix gcc check

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 90200619
......@@ -6,6 +6,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
......@@ -212,11 +213,17 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return result;
}
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program
template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
{
bool match = false;
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
......@@ -224,6 +231,11 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
return;
if(trace)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins);
}
m.apply(p, r);
match = true;
},
......
......@@ -23,8 +23,8 @@ struct convolution
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
padding_mode_t padding_mode = default_;
int group = 1;
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -32,8 +32,8 @@ struct convolution
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
f(self.group, "group"),
f(self.padding_mode, "padding_mode"));
}
std::string name() const { return "convolution"; }
......@@ -45,6 +45,9 @@ struct convolution
const shape& weights = inputs.at(1);
auto t = input.type();
if(input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
return {t,
{
input.lens()[0],
......
......@@ -613,6 +613,13 @@ struct find_conv_dot_horiz_fusion
auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name))
return;
auto op = (*start)->get_operator();
int group = 1;
if(name == "convolution")
group = any_cast<op::convolution>(op).group;
// Skip group convolution
if(group != 1)
return;
auto input = (*start)->inputs().front();
std::vector<instruction_ref> args;
std::transform(
......@@ -628,9 +635,8 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args)
p.move_instructions(arg, input);
// TODO: Check if axises match
auto concat = p.insert_instruction(input, op::concat{concat_axis}, args);
auto fused =
p.insert_instruction(std::next(input), (*start)->get_operator(), input, concat);
auto concat = p.insert_instruction(input, op::concat{concat_axis}, args);
auto fused = p.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
......
......@@ -1137,7 +1137,28 @@ TEST_CASE(simplify_conv_horiz)
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_groups)
TEST_CASE(simplify_group_conv_horiz)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {1, 32, 111, 111}};
auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto w1 = p1.add_literal(migraphx::generate_literal(ws, 1));
auto w2 = p1.add_literal(migraphx::generate_literal(ws, 2));
auto conv1 =
p1.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w1);
auto conv2 =
p1.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w2);
p1.add_instruction(pass_op{}, conv1, conv2);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_grouped)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
......@@ -1184,7 +1205,7 @@ TEST_CASE(simplify_conv_horiz_groups)
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_groups_extra1)
TEST_CASE(simplify_conv_horiz_grouped_extra1)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
......@@ -1238,7 +1259,7 @@ TEST_CASE(simplify_conv_horiz_groups_extra1)
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_groups_extra2)
TEST_CASE(simplify_conv_horiz_grouped_extra2)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment