Commit b987e44d authored by Paul's avatar Paul
Browse files

Add cbr: a fused convolution + bias + ReLU operator built on the MIOpen fusion API

parent 8d6769b6
...@@ -270,6 +270,13 @@ MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins) ...@@ -270,6 +270,13 @@ MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
return ins->get_shape().broadcasted(); return ins->get_shape().broadcasted();
} }
// Matches the single consumer of `ins`: succeeds with the sole output
// instruction, fails (ctx.not_found()) when there are zero or many outputs.
MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
    const auto& outs = ins->outputs();
    if(outs.size() != 1)
        return ctx.not_found();
    return outs.front();
}
inline auto name(std::string name) inline auto name(std::string name)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
......
...@@ -185,13 +185,73 @@ struct miopen_conv_bias ...@@ -185,13 +185,73 @@ struct miopen_conv_bias
} }
}; };
// Fused convolution + bias + ReLU executed through the MIOpen fusion API.
// Instruction argument layout (see match_conv_bias_relu::apply):
//   args[0] = conv input, args[1] = weights, args[2] = workspace,
//   args[3] = bias, args[4] = output allocation (also the return value).
struct miopen_conv_bias_relu
{
    op::convolution op; // attributes of the convolution being fused
    fusion f;           // owns the underlying MIOpen fusion plan
    fusion::op_t conv;  // fusion-op handle used to bind conv args at run time
    fusion::op_t bias;  // fusion-op handle used to bind the bias at run time
    miopen_conv_bias_relu(op::convolution c, shape input, shape weights, shape b) : op(c), f(input)
    {
        // Keep the op handles returned by the fusion builder: compute() needs
        // them for miopenSetOpArgs*. (Previously the results were discarded,
        // leaving `conv`/`bias` default-constructed and never valid.)
        conv = f.create_conv(op, weights);
        bias = f.create_bias(b);
        f.create_relu();
    }
    std::string name() const { return "gpu::conv_bias_relu"; }
    // The fused output shape is just the convolution's; bias add and relu are
    // shape-preserving, so only the first two inputs matter here.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
    {
        auto fargs = make_fused_args();
        float alpha = 1, beta = 0;
        auto x = make_tensor(args[0].get_shape());
        auto y = make_tensor(output_shape);
        // Bind the runtime pointers to the fusion ops created in the ctor.
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
        // NOTE(review): args[2] (the workspace) is not passed to the plan
        // here — confirm the compiled plan does not require it.
        miopenExecuteFusionPlan(ctx.handle.get(),
                                f.get(),
                                x.get(),
                                args[0].implicit(),
                                y.get(),
                                args[4].implicit(),
                                fargs.get());
        // Result was written into the caller-provided output allocation.
        return args.at(4);
    }
    // Compiles the fusion plan and returns the workspace shape it requires.
    // Throws when MIOpen rejects the plan (e.g. unsupported fusion).
    shape compile(context& ctx)
    {
        int algo_count = 1;
        miopenConvFwdAlgorithm_t algo;
        miopenFusionPlanConvolutionGetAlgo(f.get(), 1, &algo_count, &algo);
        std::size_t ws_size = 0;
        miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), f.get(), &ws_size, algo);
        auto status = miopenCompileFusionPlan(ctx.handle.get(), f.get());
        if(status != miopenStatusSuccess)
            MIGRAPH_THROW("Compiling fusion plan failed");
        return shape{shape::int8_type, {ws_size}};
    }
};
// Builds the matcher for a gpu::add whose two arguments (in either order) are
// a broadcast bias (bound as "bias") and a fusable convolution (bound as
// "conv"); any extra matchers `ms` are applied to the add itself.
template <class... Ms>
auto conv_bias(Ms... ms)
{
    auto args_matcher =
        match::either_arg(0, 1)(bias_shape().bind("bias"), fusable_conv().bind("conv"));
    return match::name("gpu::add")(args_matcher, ms...);
}
struct match_conv_bias struct match_conv_bias
{ {
context* ctx = nullptr; context* ctx = nullptr;
auto matcher() const auto matcher() const
{ {
return match::name("gpu::add")( return conv_bias(match::none_of(match::output(match::name("gpu::relu"))));
match::either_arg(0, 1)(bias_shape().bind("bias"), fusable_conv().bind("conv")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
...@@ -214,9 +274,37 @@ struct match_conv_bias ...@@ -214,9 +274,37 @@ struct match_conv_bias
} }
}; };
// Fuses a gpu::relu applied to the conv+bias pattern into one
// gpu::conv_bias_relu instruction backed by an MIOpen fusion plan.
struct match_conv_bias_relu
{
    context* ctx = nullptr; // GPU context used to compile the fusion plan
    // Match a gpu::relu whose input matches conv_bias(); the sub-matcher
    // binds "conv" and "bias" for use in apply() below.
    auto matcher() const
    {
        return match::name("gpu::relu")(conv_bias());
    }
    void apply(program& p, match::matcher_result r) const
    {
        auto conv_ins = r.instructions["conv"];
        auto bias_ins = r.instructions["bias"];
        // The matched gpu::relu instruction that gets replaced.
        auto ins = r.result;
        auto input_ins = conv_ins->inputs().at(0);
        auto weights_ins = conv_ins->inputs().at(1);
        // Recover the plain convolution attributes from the gpu operator.
        auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
        // Last input of the relu is its output allocation; the fused
        // instruction writes into the same buffer.
        auto alloc_ins = ins->inputs().back();
        // Reuse the convolution's existing workspace allocation for now.
        auto old_ws_ins = conv_ins->inputs().at(2);
        miopen_conv_bias_relu cbr{
            conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
        // TODO: Insert ws allocation
        // NOTE(review): `ws` (the workspace shape the compiled plan needs) is
        // currently unused; the old conv workspace is passed instead —
        // confirm it is always large enough until the TODO is done.
        auto ws = cbr.compile(*ctx);
        p.replace_instruction(ins, cbr, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
    }
};
void fuse_ops::apply(program& p) const void fuse_ops::apply(program& p) const
{ {
match::find_matches(p, match_add_relu{}, match_conv_bias{ctx}); match::find_matches(p, match_add_relu{}, match_conv_bias_relu{ctx}, match_conv_bias{ctx});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment