Commit 056318a0 authored by Paul's avatar Paul
Browse files

Skip group convolution in mlir

parent da0dccc7
......@@ -61,13 +61,25 @@ struct mlir_conv
MIGRAPHX_REGISTER_OP(mlir_conv);
namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
if (ins->name() != "convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if (group != 1)
return false;
return true;
}
struct find_conv_pointwise
{
// Find a convolution followed by a pointwise operation.
auto matcher() const
{
auto convolution =
match::skip(match::name("contiguous"))(match::name("convolution").bind("convolution"));
match::skip(match::name("contiguous"))(is_mlir_conv().bind("convolution"));
return match::name("pointwise")(match::any_of[match::inputs()](convolution.bind("x")));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment