Commit 980fc710 authored by Paul's avatar Paul
Browse files

Formatting

parent 59504832
...@@ -34,7 +34,7 @@ struct check_shapes ...@@ -34,7 +34,7 @@ struct check_shapes
std::size_t size() const std::size_t size() const
{ {
if (begin == end) if(begin == end)
return 0; return 0;
assert(begin != nullptr); assert(begin != nullptr);
assert(end != nullptr); assert(end != nullptr);
......
...@@ -93,7 +93,10 @@ struct fusion ...@@ -93,7 +93,10 @@ struct fusion
MIGRAPH_THROW("Compiling fusion plan failed"); MIGRAPH_THROW("Compiling fusion plan failed");
} }
argument execute(context& ctx, const fused_operator_args& fargs, const argument& x, const argument& y) const argument execute(context& ctx,
const fused_operator_args& fargs,
const argument& x,
const argument& y) const
{ {
auto x_td = make_tensor(x.get_shape()); auto x_td = make_tensor(x.get_shape());
auto y_td = make_tensor(y.get_shape()); auto y_td = make_tensor(y.get_shape());
...@@ -186,8 +189,7 @@ struct miopen_conv_bias ...@@ -186,8 +189,7 @@ struct miopen_conv_bias
// TODO: Check slices // TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
auto fargs = make_fused_args(); auto fargs = make_fused_args();
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
...@@ -225,8 +227,7 @@ struct miopen_conv_bias_relu ...@@ -225,8 +227,7 @@ struct miopen_conv_bias_relu
// TODO: Check slices // TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.compute_shape({inputs.at(0), inputs.at(1)});
} }
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
auto fargs = make_fused_args(); auto fargs = make_fused_args();
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
...@@ -246,11 +247,12 @@ struct miopen_conv_bias_relu ...@@ -246,11 +247,12 @@ struct miopen_conv_bias_relu
template <class... Ms> template <class... Ms>
auto conv_bias(Ms... ms) auto conv_bias(Ms... ms)
{ {
return match::name("gpu::add")( return match::name("gpu::add")(match::either_arg(0, 1)(match::arg(0)(bias_shape()).bind("bias"),
match::either_arg(0, 1)(match::arg(0)(bias_shape()).bind("bias"), fusable_conv().bind("conv")), ms...); fusable_conv().bind("conv")),
ms...);
} }
template<class Op> template <class Op>
void apply_conv_bias(context& ctx, program& p, match::matcher_result r) void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
...@@ -262,8 +264,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r) ...@@ -262,8 +264,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
auto alloc_ins = ins->inputs().back(); auto alloc_ins = ins->inputs().back();
auto old_ws_ins = conv_ins->inputs().at(2); auto old_ws_ins = conv_ins->inputs().at(2);
Op cb{ Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
// TODO: Insert ws allocation // TODO: Insert ws allocation
auto ws = cb.compile(ctx); auto ws = cb.compile(ctx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment