Commit 980fc710 authored by Paul's avatar Paul
Browse files

Formatting

parent 59504832
......@@ -34,7 +34,7 @@ struct check_shapes
std::size_t size() const
{
if (begin == end)
if(begin == end)
return 0;
assert(begin != nullptr);
assert(end != nullptr);
......
......@@ -93,7 +93,10 @@ struct fusion
MIGRAPH_THROW("Compiling fusion plan failed");
}
argument execute(context& ctx, const fused_operator_args& fargs, const argument& x, const argument& y) const
argument execute(context& ctx,
const fused_operator_args& fargs,
const argument& x,
const argument& y) const
{
auto x_td = make_tensor(x.get_shape());
auto y_td = make_tensor(y.get_shape());
......@@ -186,8 +189,7 @@ struct miopen_conv_bias
// TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument
compute(context& ctx, const shape&, const std::vector<argument>& args) const
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto fargs = make_fused_args();
float alpha = 1, beta = 0;
......@@ -225,8 +227,7 @@ struct miopen_conv_bias_relu
// TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument
compute(context& ctx, const shape&, const std::vector<argument>& args) const
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto fargs = make_fused_args();
float alpha = 1, beta = 0;
......@@ -246,11 +247,12 @@ struct miopen_conv_bias_relu
template <class... Ms>
auto conv_bias(Ms... ms)
{
return match::name("gpu::add")(
match::either_arg(0, 1)(match::arg(0)(bias_shape()).bind("bias"), fusable_conv().bind("conv")), ms...);
return match::name("gpu::add")(match::either_arg(0, 1)(match::arg(0)(bias_shape()).bind("bias"),
fusable_conv().bind("conv")),
ms...);
}
template<class Op>
template <class Op>
void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
{
auto conv_ins = r.instructions["conv"];
......@@ -262,8 +264,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
auto alloc_ins = ins->inputs().back();
auto old_ws_ins = conv_ins->inputs().at(2);
Op cb{
conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
// TODO: Insert ws allocation
auto ws = cb.compile(ctx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment