Commit aa0b6230 authored by Paul

Enable bn rewrite and check for literals

parent 5ff89419
@@ -10,23 +10,28 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        if(ins->op.name() == "batch_norm_inference")
-        {
-            auto ins_prev = ins->arguments[0];
-            if(ins_prev->op.name() == "convolution")
-            {
+        if(ins->op.name() != "batch_norm_inference")
+            continue;
+        if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
+               return arg->op.name() == "@literal";
+           }))
+            continue;
+        auto conv_ins = ins->arguments[0];
+        if(conv_ins->op.name() != "convolution")
+            continue;
+        if(conv_ins->arguments[1]->op.name() != "@literal")
+            continue;
         // Get scale, bias, mean, variance from instruction_ref
-        auto gamma    = ins->arguments[1]->lit;
-        auto bias     = ins->arguments[2]->lit;
-        auto mean     = ins->arguments[3]->lit;
-        auto variance = ins->arguments[4]->lit;
+        const auto& gamma    = ins->arguments[1]->get_literal();
+        const auto& bias     = ins->arguments[2]->get_literal();
+        const auto& mean     = ins->arguments[3]->get_literal();
+        const auto& variance = ins->arguments[4]->get_literal();
         // Get epsilon
         auto bn_op   = any_cast<batch_norm_inference>(ins->op);
         auto epsilon = bn_op.epsilon;
         // Get convolution weights
-        auto weights = ins_prev->arguments[1]->lit;
+        const auto& weights = conv_ins->arguments[1]->get_literal();
         // Get convolution op
-        auto conv_op = ins_prev->op;
+        auto conv_op = conv_ins->op;
         auto out_channels = weights.get_shape().lens()[0];
         auto in_channels  = weights.get_shape().lens()[1];
         auto height       = weights.get_shape().lens()[2];
@@ -54,10 +59,8 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
         auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
         auto l_bias    = p.add_literal({bias.get_shape(), new_bias.data()});
         auto c =
-            p.replace_instruction(ins_prev, conv_op, {ins_prev->arguments[0], l_weights});
+            p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
         p.replace_instruction(ins, add{}, {c, l_bias});
-            }
-        }
     }
 }
 } // namespace migraph
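
The hunk at @@ -54,10 +59,8 @@ is collapsed above, but the names that survive in the visible context (new_weights, new_bias, out_channels, epsilon) point at the usual per-output-channel folding of batch-norm inference into the preceding convolution, with the shift emitted as the trailing add{}. The following standalone sketch shows that folding; the function name, data layout, and loop body are assumptions for illustration, not the commit's actual code:

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration of conv + batch-norm folding.
// weights: out_channels x (in_channels * h * w), channel-major (assumed layout).
// out_bias: per-channel shift that becomes the operand of the trailing add{}.
void fold_batchnorm(std::vector<float>& weights,
                    std::vector<float>& out_bias,
                    const std::vector<float>& gamma,
                    const std::vector<float>& beta,
                    const std::vector<float>& mean,
                    const std::vector<float>& variance,
                    float epsilon,
                    std::size_t out_channels)
{
    std::size_t per_channel = weights.size() / out_channels;
    for(std::size_t c = 0; c < out_channels; c++)
    {
        // BN(x) = gamma * (x - mean) / sqrt(variance + epsilon) + beta,
        // distributed over the linear convolution output.
        float scale = gamma[c] / std::sqrt(variance[c] + epsilon);
        for(std::size_t i = 0; i < per_channel; i++)
            weights[c * per_channel + i] *= scale; // scale folded into the weights
        out_bias[c] = beta[c] - mean[c] * scale;   // shift folded into the bias
    }
}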
@@ -115,6 +115,11 @@ struct instruction
     }

     shape get_shape() const { return result; }

+    const literal& get_literal() const
+    {
+        assert(op.name() == "@literal");
+        return lit;
+    }
+
     friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
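
The new get_literal() accessor replaces the direct ->lit reads used before this commit. Its assert documents the contract that callers must first verify op.name() == "@literal", which is exactly the check the rewrite now performs up front: once with std::all_of over the batch-norm arguments, and once for the convolution weights.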
@@ -10,6 +10,7 @@
 #include <migraph/dead_code_elimination.hpp>
 #include <migraph/simplify_reshapes.hpp>
 #include <migraph/eliminate_contiguous.hpp>
+#include <migraph/fwd_conv_batchnorm_rewrite.hpp>

 namespace migraph {
 namespace gpu {

@@ -24,6 +25,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
        auto_contiguous{},
        simplify_reshapes{},
        dead_code_elimination{},
+       fwd_conv_batchnorm_rewrite{},
+       dead_code_elimination{},
        lowering{ctx},
        fuse_ops{},
        dead_code_elimination{},
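
The rewrite is registered ahead of lowering{ctx}, so GPU lowering already sees a plain convolution followed by an add instead of the conv + batch_norm_inference pair. The dead_code_elimination{} placed directly after it is presumably there to remove the replaced batch_norm_inference instruction and its now-unreferenced literal arguments before lowering runs.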