Commit a4178130 authored by Paul's avatar Paul
Browse files

Enable miopen fusions

parent acd4e0fa
...@@ -131,7 +131,15 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -131,7 +131,15 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{ {
if(ins->name() != "gpu::convolution") if(ins->name() != "gpu::convolution")
return false; return false;
auto op = any_cast<miopen_convolution>(ins->get_operator()).op; auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4);
auto channels = wei.lens()[1] * wei.lens()[0];
if(wei.lens()[0] > 64 and channels > 32768)
return false;
auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.algo == miopenConvolutionFwdAlgoWinograd)
return false;
auto op = conv.op;
return op.padding == make_array<size_t>(0, 0) and op.stride == make_array<size_t>(1, 1) and return op.padding == make_array<size_t>(0, 0) and op.stride == make_array<size_t>(1, 1) and
op.dilation == make_array<size_t>(1, 1); op.dilation == make_array<size_t>(1, 1);
} }
...@@ -366,10 +374,10 @@ struct match_conv_bias_relu ...@@ -366,10 +374,10 @@ struct match_conv_bias_relu
void fuse_ops::apply(program& p) const void fuse_ops::apply(program& p) const
{ {
// clang-format off // clang-format off
match::find_matches(p, match_triadd{});
match::find_matches(p, match::find_matches(p,
match_triadd{}, match_conv_bias_relu{ctx},
// match_conv_bias_relu{ctx}, match_conv_bias{ctx},
// match_conv_bias{ctx},
match_add_relu{} match_add_relu{}
); );
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment