Commit 41e236d3 authored by Paul's avatar Paul
Browse files

Add matcher for fusable conv

parent 7dbc89d3
...@@ -76,7 +76,7 @@ struct bindable_matcher ...@@ -76,7 +76,7 @@ struct bindable_matcher
{ {
M m; M m;
auto bind(std::string name) { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{ {
...@@ -137,7 +137,7 @@ struct basic_matcher ...@@ -137,7 +137,7 @@ struct basic_matcher
}); });
} }
auto bind(std::string name) { return bind_match(m, name); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{ {
...@@ -181,7 +181,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) ...@@ -181,7 +181,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{ \ { \
bool operator()(__VA_ARGS__) const; \ bool operator()(__VA_ARGS__) const; \
}; \ }; \
const constexpr auto name = migraph::match::basic_matcher<predicate_matcher<name##_m>>{{}}; \ const constexpr auto name = migraph::match::basic_matcher<migraph::match::predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const inline bool name##_m::operator()(__VA_ARGS__) const
struct matcher_result struct matcher_result
...@@ -310,7 +310,7 @@ auto args(Ms... ms) ...@@ -310,7 +310,7 @@ auto args(Ms... ms)
}); });
} }
auto either_arg(std::size_t i, std::size_t j) inline auto either_arg(std::size_t i, std::size_t j)
{ {
return [=](auto m1, auto m2) { return [=](auto m1, auto m2) {
return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)), return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
......
...@@ -76,6 +76,34 @@ struct fusion ...@@ -76,6 +76,34 @@ struct fusion
} }
}; };
MIGRAPH_PRED_MATCHER(bias_shape, instruction_ref ins)
{
auto&& s = ins->get_shape();
return s.broadcasted() and
s.strides().size() == 4 and
s.strides()[0] == 0 and
s.strides()[1] != 0 and
s.strides()[2] == 0 and
s.strides()[3] == 0;
}
// TODO: Move to another header
template<class T, class... Ts>
std::array<T, sizeof...(Ts)+1> make_array(T x, Ts... xs)
{
return {std::move(x), std::move(static_cast<T>(xs))...};
}
MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
if(ins->name() != "gpu::convolution")
return false;
auto op = any_cast<miopen_convolution>(ins->get_operator()).op;
return op.padding == make_array<size_t>(0, 0) and
op.stride == make_array<size_t>(1, 1) and
op.dilation == make_array<size_t>(1, 1);
}
struct hip_add_relu struct hip_add_relu
{ {
std::string name() const { return "hip::add_relu"; } std::string name() const { return "hip::add_relu"; }
...@@ -168,17 +196,17 @@ struct match_conv_bias ...@@ -168,17 +196,17 @@ struct match_conv_bias
auto matcher() const auto matcher() const
{ {
return match::name("gpu::add")(match::either_arg(0, 1)( return match::name("gpu::add")(match::either_arg(0, 1)(
match::broadcast_shape().bind("bias"), match::name("gpu::convolution").bind("conv"))); bias_shape().bind("bias"), fusable_conv().bind("conv")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
auto ins = r.result;
auto input_ins = conv_ins->inputs().at(0); auto input_ins = conv_ins->inputs().at(0);
auto weights_ins = conv_ins->inputs().at(1); auto weights_ins = conv_ins->inputs().at(1);
auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op; auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
auto ins = r.result;
auto alloc_ins = ins->inputs().back(); auto alloc_ins = ins->inputs().back();
auto old_ws_ins = conv_ins->inputs().at(2); auto old_ws_ins = conv_ins->inputs().at(2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment