Commit 59504832 authored by Paul

Fix fusion errors

parent dd11be6f
@@ -32,13 +32,20 @@ struct check_shapes
         return name + ": ";
     }
+    std::size_t size() const
+    {
+        if(begin == end)
+            return 0;
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        return end - begin;
+    }
     const check_shapes& has(std::size_t n) const
     {
-        assert(begin != nullptr);
-        assert(end != nullptr);
-        if(end - begin != n)
+        if(size() != n)
             MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
-                          " but given " + std::to_string(end - begin));
+                          " but given " + std::to_string(size()));
         return *this;
     }
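With size() in place, the arity check no longer dereferences a default-constructed range: begin == end now reports zero arguments instead of tripping the null asserts. A minimal usage sketch of how the message surfaces (the operator name, constructor form, and two-input arity are illustrative assumptions, not taken from this diff):

    // Sketch: arity checking in an operator's shape computation.
    // check_shapes is assumed constructible from the shape list plus the
    // name used by prefix(); "my_op" is hypothetical.
    shape compute_shape_for_my_op(const std::vector<shape>& inputs)
    {
        // Throws "my_op: Wrong number of arguments: expected 2 but given N"
        // when N != 2; an empty list now reports 0 rather than asserting.
        check_shapes{inputs, "my_op"}.has(2);
        return inputs.front();
    }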
@@ -113,10 +120,10 @@ struct check_shapes
     template <class F>
     bool same(F f) const
     {
-        assert(begin != nullptr);
-        assert(end != nullptr);
         if(begin == end)
             return true;
+        assert(begin != nullptr);
+        assert(end != nullptr);
         auto&& key = f(*begin);
         return this->all_of([&](const shape& s) { return f(s) == key; });
     }
@@ -124,6 +131,8 @@ struct check_shapes
     template <class Predicate>
     bool all_of(Predicate p) const
     {
+        if(begin == end)
+            return true;
         assert(begin != nullptr);
         assert(end != nullptr);
         return std::all_of(begin, end, p);
@@ -131,6 +140,10 @@ struct check_shapes
     const shape* get(long i)
     {
+        if(i >= size())
+            MIGRAPH_THROW(prefix() + "Accessing shape out of bounds");
+        assert(begin != nullptr);
+        assert(end != nullptr);
         if(i < 0)
             return end - i;
         return begin + i;
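Two arithmetic details in the new bounds check are easy to miss: i >= size() compares a long against std::size_t, so a negative i converts to a large unsigned value and throws, and end - i for i < 0 points past the end, where end + i would reach the last element. A signed-safe sketch, assuming python-style negative indexing was intended (get_checked is a hypothetical free-function rendering, not code from this commit):

    // Sketch: python-style negative indexing with a signed bounds check.
    const shape* get_checked(const shape* begin, const shape* end, long i)
    {
        const long n = end - begin;
        if(i >= n || i < -n)
            MIGRAPH_THROW("Accessing shape out of bounds");
        // get_checked(..., -1) returns the last element via end + i.
        return (i < 0) ? end + i : begin + i;
    }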
...
@@ -74,6 +74,40 @@ struct fusion
         MIGRAPH_THROW("Creating operator failed");
         return result;
     }
+    shape get_workspace(context&)
+    {
+        // TODO: Use zero workspace for now
+        std::size_t ws_size = 0;
+        // int algo_count = 1;
+        // miopenConvFwdAlgorithm_t algo;
+        // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
+        // miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), fp.get(), &ws_size, algo);
+        return shape{shape::int8_type, {ws_size}};
+    }
+    void compile(context& ctx)
+    {
+        auto status = miopenCompileFusionPlan(ctx.handle.get(), fp.get());
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Compiling fusion plan failed");
+    }
+    argument execute(context& ctx,
+                     const fused_operator_args& fargs,
+                     const argument& x,
+                     const argument& y) const
+    {
+        auto x_td   = make_tensor(x.get_shape());
+        auto y_td   = make_tensor(y.get_shape());
+        auto status = miopenExecuteFusionPlan(ctx.handle.get(),
+                                              fp.get(),
+                                              x_td.get(),
+                                              x.implicit(),
+                                              y_td.get(),
+                                              y.implicit(),
+                                              fargs.get());
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Failed to execute fusion plan");
+        return y;
+    }
 };
 MIGRAPH_PRED_MATCHER(bias_shape, instruction_ref ins)
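The new get_workspace/compile/execute helpers centralize the MIOpen fusion-plan lifecycle that the fused operators previously inlined. For reference, a condensed sketch of the raw call sequence the wrapper drives (descriptor setup and per-call status checks elided; treat the exact signatures as assumptions to verify against the MIOpen headers in use):

    // Sketch: the MIOpen fusion-plan lifecycle behind struct fusion.
    miopenFusionPlanDescriptor_t fp;
    miopenCreateFusionPlan(&fp, miopenVerticalFusion, input_desc); // plan anchored on the input tensor

    miopenFusionOpDescriptor_t conv_op, bias_op, relu_op;
    miopenCreateOpConvForward(fp, &conv_op, conv_desc, weights_desc);
    miopenCreateOpBiasForward(fp, &bias_op, bias_desc);
    miopenCreateOpActivationForward(fp, &relu_op, miopenActivationRELU);

    miopenCompileFusionPlan(handle, fp); // rejects unsupported fusions

    miopenOperatorArgs_t fargs;
    miopenCreateOperatorArgs(&fargs);
    float alpha = 1, beta = 0;
    miopenSetOpArgsConvForward(fargs, conv_op, &alpha, &beta, w_ptr);
    miopenSetOpArgsBiasForward(fargs, bias_op, &alpha, &beta, b_ptr);
    miopenSetOpArgsActivForward(fargs, relu_op, &alpha, &beta, 0, 0, 0);

    miopenExecuteFusionPlan(handle, fp, input_desc, x_ptr, output_desc, y_ptr, fargs);

    miopenDestroyOperatorArgs(fargs);
    miopenDestroyFusionPlan(fp);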
@@ -141,8 +175,8 @@ struct miopen_conv_bias
     miopen_conv_bias(op::convolution c, shape input, shape weights, shape b) : op(c), f(input)
     {
-        f.create_conv(op, weights);
-        f.create_bias(b);
+        conv = f.create_conv(op, weights);
+        bias = f.create_bias(b);
     }
     std::string name() const { return "gpu::conv_bias"; }
@@ -153,35 +187,19 @@ struct miopen_conv_bias
         return op.compute_shape({inputs.at(0), inputs.at(1)});
     }
     argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    compute(context& ctx, const shape&, const std::vector<argument>& args) const
     {
         auto fargs  = make_fused_args();
         float alpha = 1, beta = 0;
-        auto x = make_tensor(args[0].get_shape());
-        auto y = make_tensor(output_shape);
         miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
         miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
-        miopenExecuteFusionPlan(ctx.handle.get(),
-                                f.get(),
-                                x.get(),
-                                args[0].implicit(),
-                                y.get(),
-                                args[4].implicit(),
-                                fargs.get());
-        return args.at(4);
+        return f.execute(ctx, fargs, args[0], args[4]);
     }
     shape compile(context& ctx)
     {
-        int algo_count = 1;
-        miopenConvFwdAlgorithm_t algo;
-        miopenFusionPlanConvolutionGetAlgo(f.get(), 1, &algo_count, &algo);
-        std::size_t ws_size = 0;
-        miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), f.get(), &ws_size, algo);
-        auto status = miopenCompileFusionPlan(ctx.handle.get(), f.get());
-        if(status != miopenStatusSuccess)
-            MIGRAPH_THROW("Compiling fusion plan failed");
-        return shape{shape::int8_type, {ws_size}};
+        f.compile(ctx);
+        return f.get_workspace(ctx);
     }
 };
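With compute() deferring to f.execute, the remaining contract is the operand order, which the replace_instruction call further down establishes. An operand map inferred from this diff (not from a header):

    // Inferred operand layout for gpu::conv_bias (and conv_bias_relu):
    //   args[0] = x (input)      -> fusion input passed to f.execute
    //   args[1] = weights        -> miopenSetOpArgsConvForward
    //   args[2] = workspace      -> carried over from the plain conv; unused
    //                               now that get_workspace() reports 0 bytes
    //   args[3] = bias           -> miopenSetOpArgsBiasForward
    //   args[4] = y (allocation) -> fusion output, returned by f.execute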
@@ -191,12 +209,13 @@ struct miopen_conv_bias_relu
     fusion f;
     fusion::op_t conv;
     fusion::op_t bias;
+    fusion::op_t relu;
     miopen_conv_bias_relu(op::convolution c, shape input, shape weights, shape b) : op(c), f(input)
     {
-        f.create_conv(op, weights);
-        f.create_bias(b);
-        f.create_relu();
+        conv = f.create_conv(op, weights);
+        bias = f.create_bias(b);
+        relu = f.create_relu();
     }
     std::string name() const { return "gpu::conv_bias_relu"; }
@@ -207,35 +226,20 @@ struct miopen_conv_bias_relu
         return op.compute_shape({inputs.at(0), inputs.at(1)});
     }
     argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    compute(context& ctx, const shape&, const std::vector<argument>& args) const
     {
         auto fargs  = make_fused_args();
         float alpha = 1, beta = 0;
-        auto x = make_tensor(args[0].get_shape());
-        auto y = make_tensor(output_shape);
         miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
         miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
-        miopenExecuteFusionPlan(ctx.handle.get(),
-                                f.get(),
-                                x.get(),
-                                args[0].implicit(),
-                                y.get(),
-                                args[4].implicit(),
-                                fargs.get());
-        return args.at(4);
+        miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
+        return f.execute(ctx, fargs, args[0], args[4]);
     }
     shape compile(context& ctx)
     {
-        int algo_count = 1;
-        miopenConvFwdAlgorithm_t algo;
-        miopenFusionPlanConvolutionGetAlgo(f.get(), 1, &algo_count, &algo);
-        std::size_t ws_size = 0;
-        miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), f.get(), &ws_size, algo);
-        auto status = miopenCompileFusionPlan(ctx.handle.get(), f.get());
-        if(status != miopenStatusSuccess)
-            MIGRAPH_THROW("Compiling fusion plan failed");
-        return shape{shape::int8_type, {ws_size}};
+        f.compile(ctx);
+        return f.get_workspace(ctx);
     }
 };
@@ -243,19 +247,12 @@ template <class... Ms>
 auto conv_bias(Ms... ms)
 {
     return match::name("gpu::add")(
-        match::either_arg(0, 1)(bias_shape().bind("bias"), fusable_conv().bind("conv")), ms...);
+        match::either_arg(0, 1)(match::arg(0)(bias_shape()).bind("bias"), fusable_conv().bind("conv")), ms...);
 }
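In the matcher DSL used here, match::name constrains the operator name, either_arg(0, 1) tries both operand orders, match::arg(0) descends into a matched instruction's first input, and .bind records instructions under names that apply_conv_bias later looks up. A hypothetical pattern built from the same combinators ("gpu::mul" and conv_scale are made up for illustration):

    // Hypothetical: match a gpu::mul whose operands are, in either order,
    // a fusable convolution and any scale operand, binding both for apply().
    template <class... Ms>
    auto conv_scale(Ms... ms)
    {
        return match::name("gpu::mul")(
            match::either_arg(0, 1)(match::any().bind("scale"), fusable_conv().bind("conv")), ms...);
    }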
-struct match_conv_bias
+template<class Op>
+void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
 {
-    context* ctx = nullptr;
-    auto matcher() const
-    {
-        return conv_bias(match::none_of(match::output(match::name("gpu::relu"))));
-    }
-    void apply(program& p, match::matcher_result r) const
-    {
-        auto conv_ins = r.instructions["conv"];
-        auto bias_ins = r.instructions["bias"];
-        auto ins      = r.result;
+    auto conv_ins = r.instructions["conv"];
+    auto bias_ins = r.instructions["bias"];
+    auto ins      = r.result;
@@ -265,43 +262,43 @@ struct match_conv_bias
     auto alloc_ins  = ins->inputs().back();
     auto old_ws_ins = conv_ins->inputs().at(2);
-    miopen_conv_bias cb{
+    Op cb{
         conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
     // TODO: Insert ws allocation
-    auto ws = cb.compile(*ctx);
+    auto ws = cb.compile(ctx);
     p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }
+struct match_conv_bias
+{
+    context* ctx = nullptr;
+    auto matcher() const
+    {
+        return conv_bias(match::none_of(match::output(match::name("gpu::relu"))));
+    }
+    void apply(program& p, match::matcher_result r) const
+    {
+        apply_conv_bias<miopen_conv_bias>(*ctx, p, r);
+    }
 };
 struct match_conv_bias_relu
 {
     context* ctx = nullptr;
-    auto matcher() const { return match::name("gpu::relu")(conv_bias()); }
+    auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
     void apply(program& p, match::matcher_result r) const
     {
-        auto conv_ins    = r.instructions["conv"];
-        auto bias_ins    = r.instructions["bias"];
-        auto ins         = r.result;
-        auto input_ins   = conv_ins->inputs().at(0);
-        auto weights_ins = conv_ins->inputs().at(1);
-        auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
-        auto alloc_ins   = ins->inputs().back();
-        auto old_ws_ins  = conv_ins->inputs().at(2);
-        miopen_conv_bias_relu cbr{
-            conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
-        // TODO: Insert ws allocation
-        auto ws = cbr.compile(*ctx);
-        p.replace_instruction(ins, cbr, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
+        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, r);
     }
 };
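Factoring the rewrite into apply_conv_bias<Op> means a new fused variant needs only a matcher and a one-line apply. A hypothetical extension (miopen_conv_bias_sigmoid and "gpu::sigmoid" are invented to show the shape of such an addition):

    // Hypothetical: a conv + bias + sigmoid fusion reusing the shared helper.
    struct match_conv_bias_sigmoid
    {
        context* ctx = nullptr;
        auto matcher() const
        {
            return match::name("gpu::sigmoid")(match::arg(0)(conv_bias()));
        }
        void apply(program& p, match::matcher_result r) const
        {
            apply_conv_bias<miopen_conv_bias_sigmoid>(*ctx, p, r);
        }
    };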
 void fuse_ops::apply(program& p) const
 {
-    match::find_matches(p, match_add_relu{}, match_conv_bias_relu{ctx}, match_conv_bias{ctx});
+    // match::find_matches(p, match_conv_bias_relu{ctx}, match_conv_bias{ctx}, match_add_relu{});
+    match::find_matches(p, match_conv_bias{ctx}, match_add_relu{});
 }
 } // namespace gpu
...