Unverified commit 534a05c1, authored by Paul Fultz II, committed by GitHub

Handle miopen fusions when using pointwise fusions (#1019)

* Add matcher for conv_bias pointwise
* Add fusion op
parent 594f2802
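
Context for the diff below: once the pointwise-fusion pass runs, a convolution's bias add and activation arrive wrapped in a gpu::precompile_op whose operand is a pointwise module, so the existing matchers that look for separate gpu::add/gpu::relu instructions no longer see them. This commit adds a gpu::miopen_fusion operator that replays a recorded op list into a MIOpen fusion plan, plus a matcher that rewrites the pointwise form into it. A rough before/after sketch of the intent (the instruction spellings are illustrative, not exact MIGraphX printouts):

    conv = gpu::convolution(input, weights, workspace, alloc1)
    out  = gpu::precompile_op[op=pointwise](conv, bias, alloc2)   // module body: add, relu

    // rewritten, when MIOpen accepts the plan, to:
    out  = gpu::miopen_fusion(input, weights, bias, alloc2)      // ops: convolution, add, relu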
@@ -62,6 +62,8 @@ struct fusion
        keep_alive(std::move(t));
    }

+   bool empty() const { return fp == nullptr; }
+
    op_t operator[](std::size_t i) const
    {
        assert(fp);
@@ -125,12 +127,11 @@ struct fusion
        return shape{shape::int8_type, {ws_size}};
    }

-   void compile(context& ctx)
+   bool compile(context& ctx)
    {
        assert(fp);
-       auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
-       if(status != miopenStatusSuccess)
-           MIGRAPHX_THROW("Compiling fusion plan failed");
+       return miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get()) ==
+              miopenStatusSuccess;
    }

    argument execute(context& ctx,
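
The effect of this first change: fusion::compile now reports MIOpen's verdict as a bool instead of throwing, so each call site picks its own failure policy. A condensed view of the two policies that appear later in this diff:

    if(not f.compile(ctx))                               // miopen_conv_bias: the fusion is required,
        MIGRAPHX_THROW("Failed to compile fusion plan"); // so a rejected plan is fatal

    if(not f.compile(ctx))
    {
        f = {};
        return {}; // miopen_fusion::compile: an empty value tells the matcher to skip the rewrite
    }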
@@ -561,6 +562,117 @@ struct find_mul_add_relu
    }
};
+
+struct miopen_fusion
+{
+    struct fuse_op_data
+    {
+        operation op;
+        float alpha = 1;
+        float beta  = 0;
+    };
+    struct fuse_op : fuse_op_data, reflect_equality<fuse_op>, reflect_stream<fuse_op>
+    {
+        template <class Self, class F>
+        static auto reflect(Self& self, F f)
+        {
+            return pack(f(self.op, "op"), f(self.alpha, "alpha"), f(self.beta, "beta"));
+        }
+    };
+    std::vector<fuse_op> ops = {};
+    fusion f                 = {};
+    std::function<void(context&, const fusion&, const std::vector<argument>&)> execute;
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.ops, "ops"));
+    }
+    value compile(context& ctx, const shape&, std::vector<shape> inputs)
+    {
+        // Compensate for allocation
+        inputs.pop_back();
+        std::size_t i = 0;
+        f             = fusion(inputs[i]);
+        i++;
+        std::vector<std::function<void(const fused_operator_args&, const std::vector<argument>&)>>
+            invokers;
+        for(auto&& fop : ops)
+        {
+            if(i > inputs.size())
+            {
+                f = {};
+                return {};
+            }
+            if(fop.op.name() == "convolution")
+            {
+                auto* mop = f.create_conv(any_cast<op::convolution>(fop.op), inputs[i]);
+                invokers.push_back(
+                    [=](const fused_operator_args& fargs, const std::vector<argument>& args) {
+                        miopenSetOpArgsConvForward(
+                            fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
+                    });
+                i++;
+            }
+            else if(fop.op.name() == "add")
+            {
+                auto* mop = f.create_bias(inputs[i]);
+                invokers.push_back(
+                    [=](const fused_operator_args& fargs, const std::vector<argument>& args) {
+                        miopenSetOpArgsBiasForward(
+                            fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
+                    });
+                i++;
+            }
+            else if(fop.op.name() == "relu")
+            {
+                auto* mop = f.create_relu();
+                invokers.push_back([=](const fused_operator_args& fargs,
+                                       const std::vector<argument>&) {
+                    miopenSetOpArgsActivForward(fargs.get(), mop, &fop.alpha, &fop.beta, 0, 0, 0);
+                });
+            }
+            else
+            {
+                f = {};
+                return {};
+            }
+        }
+        if(not f.compile(ctx))
+        {
+            f = {};
+            return {};
+        }
+        execute = [invokers](context& c, const fusion& ff, const std::vector<argument>& args) {
+            auto fargs = make_fused_args();
+            for(auto&& invoker : invokers)
+                invoker(fargs, args);
+            ff.execute(c, fargs, args.front(), args.back());
+        };
+        return {{"workspace", f.get_workspace(ctx).bytes()}};
+    }
+    void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
+    {
+        if(not f.empty())
+            return;
+        auto v = compile(ctx, output_shape, inputs);
+        if(not v.is_object())
+            MIGRAPHX_THROW("Failed to compile fusion plan");
+    }
+    std::string name() const { return "gpu::miopen_fusion"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if(ops.empty())
+            return {};
+        // TODO: Check number of arguments
+        return ops.front().op.compute_shape({inputs[0], inputs[1]});
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        execute(ctx, f, args);
+        return args.back();
+    }
+};
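
Two details worth noting in miopen_fusion. First, reflect serializes only ops, so after deserialization the MIOpen plan handle and the execute closure are gone; finalize spots the empty fusion, re-runs compile, and only then turns failure into an exception. Second, compile reports every problem (unsupported op, too few inputs, plan rejected by MIOpen) by resetting f and returning an empty value. A hedged caller-side sketch, assuming make_op and the surrounding context and shapes are in scope (the real call site is find_conv_pointwise::apply below):

    miopen_fusion op{};
    op.ops.push_back({{conv_op}});         // convolution first; consumes inputs[1] (weights)
    op.ops.push_back({{make_op("add")}});  // then the pointwise ops in module order
    auto v = op.compile(ctx, output_shape, input_shapes);
    if(not v.is_object())
        return;  // MIOpen declined the plan; leave the graph unchanged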
struct miopen_conv_bias
{
    op::convolution op;

@@ -596,7 +708,8 @@ struct miopen_conv_bias
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
-       f.compile(ctx);
+       if(not f.compile(ctx))
+           MIGRAPHX_THROW("Failed to compile fusion plan");
    }

    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
@@ -683,6 +796,25 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
}
+
+inline auto precompile_name(std::string s) // NOLINT
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(ins->name() != "gpu::precompile_op")
+            return false;
+        auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
+        return (op.name() == s);
+    });
+}
+
+template <class... Ms>
+auto conv_bias_pointwise(Ms... ms)
+{
+    return precompile_name("pointwise")(
+        match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
+                                fusable_conv(match::used_once()).bind("conv")),
+        ms...);
+}
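
For orientation: gpu::precompile_op carries its wrapped operation in its serialized value, and precompile_name matches by peeking at that inner op's name, so precompile_name("pointwise") fires on any precompiled pointwise module regardless of its body. Note that conv_bias_pointwise is introduced here as a reusable matcher, but find_conv_pointwise below spells out an effectively equivalent pattern inline with an extra argument-count constraint:

    auto m = conv_bias_pointwise(match::nargs(3));  // effectively the same predicate as find_conv_pointwise's matcher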
struct find_conv_bias
{
    context* ctx = nullptr;

@@ -709,6 +841,46 @@ struct find_conv_bias_relu
    }
};
+
+struct find_conv_pointwise
+{
+    context* ctx = nullptr;
+    auto matcher() const
+    {
+        return precompile_name("pointwise")(
+            match::nargs(3),
+            match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
+                                    fusable_conv(match::used_once()).bind("conv")));
+    }
+    void apply(module& m, match::matcher_result r) const
+    {
+        auto conv_ins    = r.instructions["conv"];
+        auto bias_ins    = r.instructions["bias"];
+        auto ins         = r.result;
+        auto input_ins   = conv_ins->inputs().at(0);
+        auto weights_ins = conv_ins->inputs().at(1);
+        auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
+        auto alloc_ins   = ins->inputs().back();
+        module_ref pm    = ins->module_inputs().front();
+        miopen_fusion op{};
+        op.ops.push_back({{conv_op}});
+        for(auto&& i : *pm)
+        {
+            if(i.name()[0] == '@')
+                continue;
+            op.ops.push_back({{i.get_operator()}});
+        }
+        std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};
+        auto v = op.compile(*ctx, ins->get_shape(), to_shapes(inputs));
+        if(not v.is_object())
+            return;
+        m.replace_instruction(ins, op, inputs);
+    }
+};
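
A walk through apply for a conv, add, relu pointwise module, to make the input bookkeeping concrete: builtin instructions in the module (names starting with '@', e.g. @param and @return) carry no computation and are skipped, so ops ends up as {convolution, add, relu}. The inputs handed to compile then line up as:

    // inputs = {input, weights, bias, alloc}
    //   alloc     -> popped first; it is the output buffer, not a fusion input
    //   inputs[0] -> seeds fusion(inputs[0]) with the input tensor's shape
    //   inputs[1] -> create_conv (weights)
    //   inputs[2] -> create_bias (bias)
    //   relu      -> consumes no input

If compile returns an empty value the instruction is left untouched; since the add and relu live inside the pointwise module, none of the older conv_bias matchers will claim it either.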
struct find_gemm_add
{
    auto matcher() const
@@ -778,6 +950,7 @@ void fuse_ops::apply(module& p) const
    match::find_matches(p, find_triadd{});
    match::find_matches(p,
                        find_layernorm{},
+                       find_conv_pointwise{ctx},
                        find_conv_bias_relu{ctx},
                        find_conv_bias{ctx},
                        find_add_gelu{},
...