Commit 7f37c6ab authored by Paul

Fuse conv and bias

parent 740f5ba1
@@ -8,13 +8,18 @@ namespace migraph {
 struct check_shapes
 {
-    const std::vector<shape>* shapes;
+    const shape* begin;
+    const shape* end;
     const std::string name;
-    check_shapes(const std::vector<shape>& s) : shapes(&s) {}
+    check_shapes(const shape* b, const shape* e, const std::string& n)
+        : begin(b), end(e), name(n)
+    {
+    }
+    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
     template <class Op>
-    check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
+    check_shapes(const std::vector<shape>& s, const Op& op)
+        : begin(s.data()), end(s.data() + s.size()), name(op.name())
     {
     }
@@ -28,19 +33,21 @@ struct check_shapes
     const check_shapes& has(std::size_t n) const
     {
-        assert(shapes != nullptr);
-        if(shapes->size() != n)
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        if(std::size_t(end - begin) != n)
             MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
-                          " but given " + std::to_string(shapes->size()));
+                          " but given " + std::to_string(end - begin));
         return *this;
     }
     const check_shapes& only_dims(std::size_t n) const
     {
-        assert(shapes != nullptr);
-        if(!shapes->empty())
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        if(begin != end)
         {
-            if(shapes->front().lens().size() != n)
+            if(begin->lens().size() != n)
                 MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
         }
         return *this;
@@ -105,18 +112,37 @@ struct check_shapes
     template <class F>
     bool same(F f) const
     {
-        assert(shapes != nullptr);
-        if(shapes->empty())
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        if(begin == end)
             return true;
-        auto&& key = f(shapes->front());
+        auto&& key = f(*begin);
         return this->all_of([&](const shape& s) { return f(s) == key; });
     }
     template <class Predicate>
     bool all_of(Predicate p) const
     {
-        assert(shapes != nullptr);
-        return std::all_of(shapes->begin(), shapes->end(), p);
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        return std::all_of(begin, end, p);
+    }
+    // A negative index counts back from the end of the range
+    const shape* get(long i) const
+    {
+        if(i < 0)
+            return end + i;
+        return begin + i;
+    }
+    check_shapes slice(long start) const
+    {
+        return {get(start), end, name};
+    }
+    check_shapes slice(long start, long last) const
+    {
+        return {get(start), get(last), name};
     }
 };
...
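A quick sketch of how the new range API composes (a hypothetical call site, not part of this commit): `slice` yields a `check_shapes` over a sub-range, and `get` accepts negative indices counting back from the end, so an operator whose trailing arguments are workspace and allocation buffers can validate only its tensor inputs:

    // Hypothetical call site: five inputs, but only the first two
    // (data and weights) must be 4-dimensional.
    check_shapes{inputs, *this}.has(5);
    check_shapes{inputs, *this}.slice(0, 2).only_dims(4);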
@@ -264,6 +264,7 @@ auto any_of(Ts... ms)
 }
 MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
+MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins) { return ins->get_shape().broadcasted(); }
 inline auto name(std::string name)
 {
...
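The new predicate composes like any other matcher; a minimal sketch (hypothetical, not from this commit) that matches an add whose first argument is a broadcast:

    // Hypothetical matcher: a gpu::add whose first argument has a
    // broadcasted shape (e.g. a bias expanded across an output)
    auto m = match::name("gpu::add")(match::arg(0)(match::broadcast_shape()));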
 #include <migraph/gpu/fuse_ops.hpp>
 #include <migraph/matcher.hpp>
+#include <migraph/gpu/miopen.hpp>
+#include <migraph/gpu/convolution.hpp>
 #include <migraph/gpu/device/add_relu.hpp>
 #include <migraph/instruction.hpp>
@@ -7,6 +9,78 @@ namespace migraph {
 namespace gpu {
+struct fusion
+{
+    using op_t = miopenFusionOpDescriptor_t;
+    shared<fusion_plan_descriptor> fp;
+    // Used as a temporary hack to keep descriptor references alive
+    std::vector<std::shared_ptr<void>> storage;
+
+    template <class T>
+    auto keep_alive(T x)
+    {
+        auto result = share(std::move(x));
+        storage.push_back(result);
+        return result;
+    }
+
+    fusion(const shape& input)
+    {
+        auto t = make_tensor(input);
+        fp     = make_fusion_plan(t);
+        keep_alive(std::move(t));
+    }
+
+    op_t operator[](std::size_t i) const
+    {
+        op_t result;
+        auto status = miopenFusionPlanGetOp(fp.get(), i, &result);
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Failed retrieving operator at " + std::to_string(i));
+        return result;
+    }
+
+    auto get() const { return fp.get(); }
+
+    op_t create_bias(const shape& bias)
+    {
+        op_t result;
+        // MIOpen expects the bias as a 1xCx1x1 tensor
+        auto b      = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
+        auto t      = keep_alive(make_tensor(b));
+        auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Creating bias operator failed");
+        return result;
+    }
+
+    op_t create_relu()
+    {
+        op_t result;
+        auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU);
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Creating relu operator failed");
+        return result;
+    }
+
+    op_t create_conv(const op::convolution& op, const shape& weights)
+    {
+        op_t result;
+        auto cd     = keep_alive(make_conv(op));
+        auto t      = keep_alive(make_tensor(weights));
+        auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Creating convolution operator failed");
+        return result;
+    }
+};
 struct hip_add_relu
 {
     std::string name() const { return "hip::add_relu"; }
@@ -26,7 +100,7 @@ struct match_add_relu
 {
     auto matcher() const
     {
-        return match::name("gpu::relu")(match::args(match::name("gpu::add").bind("add")));
+        return match::name("gpu::relu")(match::arg(0)(match::name("gpu::add").bind("add")));
     }
     void apply(program& p, match::matcher_result r) const
@@ -40,7 +114,86 @@ struct match_add_relu
     }
 };
-void fuse_ops::apply(program& p) const { match::find_matches(p, match_add_relu{}); }
+struct miopen_conv_bias
+{
+    op::convolution op;
+    fusion f;
+    fusion::op_t conv;
+    fusion::op_t bias;
+
+    miopen_conv_bias(op::convolution c, shape input, shape weights, shape b) : op(c), f(input)
+    {
+        conv = f.create_conv(op, weights);
+        bias = f.create_bias(b);
+    }
+
+    std::string name() const { return "gpu::conv_bias"; }
+
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(5);
+        // TODO: Check slices
+        return op.compute_shape({inputs.at(0), inputs.at(1)});
+    }
+
+    argument
+    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    {
+        auto fargs  = make_fused_args();
+        float alpha = 1, beta = 0;
+        auto x      = make_tensor(args[0].get_shape());
+        auto y      = make_tensor(output_shape);
+        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
+        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
+        miopenExecuteFusionPlan(ctx.handle.get(),
+                                f.get(),
+                                x.get(),
+                                args[0].implicit(),
+                                y.get(),
+                                args[4].implicit(),
+                                fargs.get());
+        return args.at(4);
+    }
+
+    shape compile(context& ctx)
+    {
+        int algo_count = 1;
+        miopenConvFwdAlgorithm_t algo;
+        miopenFusionPlanConvolutionGetAlgo(f.get(), 1, &algo_count, &algo);
+        std::size_t ws_size = 0;
+        miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), f.get(), &ws_size, algo);
+        auto status = miopenCompileFusionPlan(ctx.handle.get(), f.get());
+        if(status != miopenStatusSuccess)
+            MIGRAPH_THROW("Compiling fusion plan failed");
+        return shape{shape::int8_type, {ws_size}};
+    }
+};
+struct match_conv_bias
+{
+    context* ctx = nullptr;
+
+    auto matcher() const
+    {
+        // Match a gpu::add of a broadcasted bias and a convolution, with the
+        // two arguments in either order
+        return match::name("gpu::add")(match::any_of(
+            match::all_of(match::arg(0)(match::broadcast_shape().bind("bias")),
+                          match::arg(1)(match::name("gpu::convolution").bind("conv"))),
+            match::all_of(match::arg(1)(match::broadcast_shape().bind("bias")),
+                          match::arg(0)(match::name("gpu::convolution").bind("conv")))));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto conv_ins    = r.instructions["conv"];
+        auto bias_ins    = r.instructions["bias"];
+        auto input_ins   = conv_ins->inputs().at(0);
+        auto weights_ins = conv_ins->inputs().at(1);
+        auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
+        auto ins         = r.result;
+        auto alloc_ins   = ins->inputs().back();
+        auto old_ws_ins  = conv_ins->inputs().at(2);
+        miopen_conv_bias cb{
+            conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
+        // TODO: Insert ws allocation
+        auto ws = cb.compile(*ctx);
+        p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
+    }
+};
+
+void fuse_ops::apply(program& p) const
+{
+    match::find_matches(p, match_add_relu{}, match_conv_bias{ctx});
+}
 } // namespace gpu
...
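Taken together, the pass absorbs a broadcasted bias add into the MIOpen fusion plan; per the comments above, `keep_alive`/`storage` exists because the fusion-plan calls appear to reference, rather than copy, the descriptors they are given. Roughly, the rewrite looks like this (a sketch with made-up value names; the real stream threads workspace and output allocations through as shown in `replace_instruction`):

    // Before the pass (hypothetical instruction stream):
    //   c = gpu::convolution(x, w, ws, out1)
    //   b = broadcast(bias)
    //   y = gpu::add(b, c, out2)
    // After the pass: one fused MIOpen kernel, reusing the convolution's
    // workspace and the add's output allocation:
    //   y = gpu::conv_bias(x, w, ws, b, out2)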
@@ -10,6 +10,7 @@ namespace gpu {
 struct fuse_ops
 {
+    context* ctx = nullptr;
     std::string name() const { return "gpu::fuse_ops"; }
     void apply(program& p) const;
 };
...
@@ -87,12 +87,18 @@ inline activation_descriptor make_relu()
     return ad;
 }
-inline fusion_plan_descriptor make_fusion_plan(const migraph::shape& input)
+inline fusion_plan_descriptor make_fusion_plan(const shape& input)
 {
     auto t = make_tensor(input);
     return make_obj<fusion_plan_descriptor>(&miopenCreateFusionPlan, miopenVerticalFusion, t.get());
 }
+
+// Temporary hack to work around memory problems in miopen: the caller owns
+// the tensor descriptor and must keep it alive for the plan's lifetime
+inline fusion_plan_descriptor make_fusion_plan(const tensor_descriptor& input)
+{
+    return make_obj<fusion_plan_descriptor>(&miopenCreateFusionPlan, miopenVerticalFusion, input.get());
+}
 inline fused_operator_args make_fused_args()
 {
     return make_obj<fused_operator_args>(&miopenCreateOperatorArgs);
...
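A minimal sketch of the new overload's contract (hypothetical caller; this is what `fusion`'s constructor does via `keep_alive`):

    // Hypothetical caller: t must outlive fp, since miopen apparently keeps
    // a reference to the descriptor rather than copying it
    auto t  = make_tensor(input_shape);
    auto fp = make_fusion_plan(t);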
@@ -29,7 +29,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
         simplify_reshapes{},
         dead_code_elimination{},
         lowering{ctx},
-        fuse_ops{},
+        fuse_ops{&ctx},
         dead_code_elimination{},
         eliminate_contiguous{},
         dead_code_elimination{},
...