Commit 8d6769b6 authored by Paul's avatar Paul
Browse files

Formatting

parent 41e236d3
......@@ -176,12 +176,13 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::match::basic_matcher<migraph::match::predicate_matcher<name##_m>>{{}}; \
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = \
migraph::match::basic_matcher<migraph::match::predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
struct matcher_result
......
......@@ -79,29 +79,24 @@ struct fusion
MIGRAPH_PRED_MATCHER(bias_shape, instruction_ref ins)
{
auto&& s = ins->get_shape();
return s.broadcasted() and
s.strides().size() == 4 and
s.strides()[0] == 0 and
s.strides()[1] != 0 and
s.strides()[2] == 0 and
s.strides()[3] == 0;
return s.broadcasted() and s.strides().size() == 4 and s.strides()[0] == 0 and
s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0;
}
// TODO: Move to another header
template<class T, class... Ts>
std::array<T, sizeof...(Ts)+1> make_array(T x, Ts... xs)
template <class T, class... Ts>
std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {std::move(x), std::move(static_cast<T>(xs))...};
}
}
MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
if(ins->name() != "gpu::convolution")
return false;
auto op = any_cast<miopen_convolution>(ins->get_operator()).op;
return op.padding == make_array<size_t>(0, 0) and
op.stride == make_array<size_t>(1, 1) and
op.dilation == make_array<size_t>(1, 1);
return op.padding == make_array<size_t>(0, 0) and op.stride == make_array<size_t>(1, 1) and
op.dilation == make_array<size_t>(1, 1);
}
struct hip_add_relu
......@@ -195,8 +190,8 @@ struct match_conv_bias
context* ctx = nullptr;
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(
bias_shape().bind("bias"), fusable_conv().bind("conv")));
return match::name("gpu::add")(
match::either_arg(0, 1)(bias_shape().bind("bias"), fusable_conv().bind("conv")));
}
void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment