Commit 7c416eb9 authored by Paul

Formatting

parent 7f37c6ab

@@ -12,14 +12,15 @@ struct check_shapes
     const shape* end;
     const std::string name;
-    check_shapes(const shape* b, const shape* e, const std::string& n)
-        : begin(b), end(e), name(n)
-    {}
+    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
+    {
+    }
-    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data()+s.size()) {}
+    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
     template <class Op>
-    check_shapes(const std::vector<shape>& s, const Op& op) : begin(s.data()), end(s.data()+s.size()), name(op.name())
+    check_shapes(const std::vector<shape>& s, const Op& op)
+        : begin(s.data()), end(s.data() + s.size()), name(op.name())
     {
     }

@@ -35,9 +36,9 @@ struct check_shapes
     {
         assert(begin != nullptr);
         assert(end != nullptr);
-        if(end-begin != n)
+        if(end - begin != n)
             MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
-                          " but given " + std::to_string(end-begin));
+                          " but given " + std::to_string(end - begin));
         return *this;
     }

@@ -45,7 +46,7 @@ struct check_shapes
     {
         assert(begin != nullptr);
         assert(end != nullptr);
-        if(begin!=end)
+        if(begin != end)
         {
             if(begin->lens().size() != n)
                 MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");

@@ -114,7 +115,7 @@ struct check_shapes
     {
         assert(begin != nullptr);
         assert(end != nullptr);
-        if(begin==end)
+        if(begin == end)
             return true;
         auto&& key = f(*begin);
         return this->all_of([&](const shape& s) { return f(s) == key; });

@@ -131,19 +132,13 @@ struct check_shapes
     const shape* get(long i)
     {
         if(i < 0)
-            return end-i;
-        return begin+i;
+            return end - i;
+        return begin + i;
     }
-    check_shapes slice(long start)
-    {
-        return {get(start), end, name};
-    }
+    check_shapes slice(long start) { return {get(start), end, name}; }
-    check_shapes slice(long start, long last)
-    {
-        return {get(start), get(last), name};
-    }
+    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
 };
 } // namespace migraph
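
For context, this header implements a small fluent checker that an operator builds over its input shapes and then chains validations on. The standalone sketch below illustrates the idiom only: the types are simplified, the method names `has` and `only_dims` are assumptions (the diff shows the checks but not their names), and `std::runtime_error` stands in for MIGRAPH_THROW. It is not the actual migraph header.

    // Standalone illustration of the check_shapes idiom (simplified; the real header differs).
    #include <cstddef>
    #include <stdexcept>
    #include <string>
    #include <vector>

    struct shape
    {
        std::vector<std::size_t> dims;
        const std::vector<std::size_t>& lens() const { return dims; }
    };

    struct check_shapes
    {
        const shape* begin;
        const shape* end;
        std::string name;

        check_shapes(const std::vector<shape>& s, const std::string& n)
            : begin(s.data()), end(s.data() + s.size()), name(n)
        {
        }

        std::string prefix() const { return name.empty() ? "" : name + ": "; }

        // Corresponds to the "Wrong number of arguments" check in the diff.
        const check_shapes& has(std::ptrdiff_t n) const
        {
            if(end - begin != n)
                throw std::runtime_error(prefix() + "Wrong number of arguments: expected " +
                                         std::to_string(n) + " but given " +
                                         std::to_string(end - begin));
            return *this;
        }

        // Corresponds to the "Only Nd supported" check in the diff.
        const check_shapes& only_dims(std::size_t n) const
        {
            if(begin != end and begin->lens().size() != n)
                throw std::runtime_error(prefix() + "Only " + std::to_string(n) + "d supported");
            return *this;
        }
    };

    int main()
    {
        std::vector<shape> inputs = {{{1, 3, 32, 32}}, {{8, 3, 5, 5}}};
        // Chained validation, as an operator's compute_shape would do:
        check_shapes{inputs, "convolution"}.has(2).only_dims(4);
    }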

@@ -264,7 +264,10 @@ auto any_of(Ts... ms)
 }
 MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
-MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins) { return ins->get_shape().broadcasted(); }
+MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
+{
+    return ins->get_shape().broadcasted();
+}
 inline auto name(std::string name)
 {
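
The matchers touched here are predicate matchers: callables over an instruction that either match or not, and that compose through combinators such as any_of and all_of (this is how match_conv_bias in the next file builds its pattern). The following standalone sketch conveys the composition idea with deliberately simplified, made-up types; it is not the migraph matcher implementation.

    // Minimal sketch of predicate matchers and combinators (illustrative only).
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    struct instruction
    {
        std::string op_name;
        bool standard = true;
    };

    using matcher = std::function<bool(const instruction&)>;

    // Roughly what a MIGRAPH_PRED_MATCHER-style macro provides: a named predicate.
    matcher name(const std::string& n)
    {
        return [=](const instruction& ins) { return ins.op_name == n; };
    }

    matcher standard_shape()
    {
        return [](const instruction& ins) { return ins.standard; };
    }

    // Combinators: match if any / all of the sub-matchers match.
    matcher any_of(std::vector<matcher> ms)
    {
        return [=](const instruction& ins) {
            for(const auto& m : ms)
                if(m(ins))
                    return true;
            return false;
        };
    }

    matcher all_of(std::vector<matcher> ms)
    {
        return [=](const instruction& ins) {
            for(const auto& m : ms)
                if(not m(ins))
                    return false;
            return true;
        };
    }

    int main()
    {
        instruction add{"gpu::add"};
        auto m = all_of({name("gpu::add"), any_of({standard_shape(), name("gpu::convolution")})});
        std::cout << std::boolalpha << m(add) << "\n"; // prints: true
    }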

@@ -17,7 +17,7 @@ struct fusion
     // Used as a temporary hack to keep descriptor references alive
     std::vector<std::shared_ptr<void>> storage;
-    template<class T>
+    template <class T>
     auto keep_alive(T x)
     {
         auto result = share(std::move(x));

@@ -29,11 +29,10 @@ struct fusion
     // : fp(make_fusion_plan(input))
     {
         auto t = make_tensor(input);
         fp = make_fusion_plan(t);
         keep_alive(std::move(t));
     }
     op_t operator[](std::size_t i) const
     {
         op_t result;

@@ -43,16 +42,13 @@ struct fusion
         return result;
     }
-    auto get() const
-    {
-        return fp.get();
-    }
+    auto get() const { return fp.get(); }
     op_t create_bias(const shape& bias)
     {
         op_t result;
         auto b = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
         auto t = keep_alive(make_tensor(b));
         auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
         if(status != miopenStatusSuccess)
             MIGRAPH_THROW("Creating operator failed");

@@ -71,14 +67,13 @@ struct fusion
     op_t create_conv(const op::convolution& op, const shape& weights)
     {
         op_t result;
         auto cd = keep_alive(make_conv(op));
         auto t = keep_alive(make_tensor(weights));
         auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
         if(status != miopenStatusSuccess)
             MIGRAPH_THROW("Creating operator failed");
         return result;
     }
 };
 struct hip_add_relu

@@ -121,8 +116,7 @@ struct miopen_conv_bias
     fusion::op_t conv;
     fusion::op_t bias;
-    miopen_conv_bias(op::convolution c, shape input, shape weights, shape b)
-        : op(c), f(input)
+    miopen_conv_bias(op::convolution c, shape input, shape weights, shape b) : op(c), f(input)
     {
         f.create_conv(op, weights);
         f.create_bias(b);

@@ -135,15 +129,22 @@ struct miopen_conv_bias
         // TODO: Check slices
         return op.compute_shape({inputs.at(0), inputs.at(1)});
     }
-    argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    argument
+    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
     {
         auto fargs = make_fused_args();
         float alpha = 1, beta = 0;
         auto x = make_tensor(args[0].get_shape());
         auto y = make_tensor(output_shape);
         miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
         miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
-        miopenExecuteFusionPlan(ctx.handle.get(), f.get(), x.get(), args[0].implicit(), y.get(), args[4].implicit(), fargs.get());
+        miopenExecuteFusionPlan(ctx.handle.get(),
+                                f.get(),
+                                x.get(),
+                                args[0].implicit(),
+                                y.get(),
+                                args[4].implicit(),
+                                fargs.get());
         return args.at(4);
     }

@@ -163,36 +164,39 @@ struct miopen_conv_bias
 struct match_conv_bias
 {
-    context * ctx = nullptr;
+    context* ctx = nullptr;
     auto matcher() const
     {
         return match::name("gpu::add")(match::any_of(
-            match::all_of(match::arg(0)(match::broadcast_shape().bind("bias")), match::arg(1)(match::name("gpu::convolution").bind("conv"))),
-            match::all_of(match::arg(1)(match::broadcast_shape().bind("bias")), match::arg(0)(match::name("gpu::convolution").bind("conv")))
-        ));
+            match::all_of(match::arg(0)(match::broadcast_shape().bind("bias")),
+                          match::arg(1)(match::name("gpu::convolution").bind("conv"))),
+            match::all_of(match::arg(1)(match::broadcast_shape().bind("bias")),
+                          match::arg(0)(match::name("gpu::convolution").bind("conv")))));
     }
     void apply(program& p, match::matcher_result r) const
     {
         auto conv_ins = r.instructions["conv"];
         auto bias_ins = r.instructions["bias"];
         auto input_ins = conv_ins->inputs().at(0);
         auto weights_ins = conv_ins->inputs().at(1);
         auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
         auto ins = r.result;
         auto alloc_ins = ins->inputs().back();
         auto old_ws_ins = conv_ins->inputs().at(2);
-        miopen_conv_bias cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
+        miopen_conv_bias cb{
+            conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
         // TODO: Insert ws allocation
         auto ws = cb.compile(*ctx);
         p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
     }
 };
-void fuse_ops::apply(program& p) const {
-    match::find_matches(p, match_add_relu{}, match_conv_bias{ctx});
+void fuse_ops::apply(program& p) const
+{
+    match::find_matches(p, match_add_relu{}, match_conv_bias{ctx});
 }
 } // namespace gpu
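
The fusion struct's storage/keep_alive members exist because the MIOpen fusion plan only sees raw handles to the descriptors passed into it, so the wrapper keeps shared ownership of every descriptor it creates until the plan itself goes away (the "temporary hack" comment above). A standalone sketch of that ownership idiom, with a hypothetical descriptor type in place of the MIOpen handles and std::make_shared in place of migraph's share():

    // Sketch of the keep_alive idiom: park ownership of helper objects in a
    // type-erased std::shared_ptr<void> so they outlive the calls that use them.
    #include <memory>
    #include <utility>
    #include <vector>

    struct tensor_descriptor // stand-in for a MIOpen descriptor wrapper
    {
        int id = 0;
    };

    struct fusion_like
    {
        // Used as a temporary hack to keep descriptor references alive
        std::vector<std::shared_ptr<void>> storage;

        template <class T>
        std::shared_ptr<T> keep_alive(T x)
        {
            // share(std::move(x)) in the real code; make_shared is the plain-C++ equivalent.
            auto result = std::make_shared<T>(std::move(x));
            storage.push_back(result); // type-erased copy keeps the object alive
            return result;
        }
    };

    int main()
    {
        fusion_like f;
        auto t = f.keep_alive(tensor_descriptor{42});
        // t.get() can now be handed to a C API; the descriptor stays valid for
        // as long as `f` exists, even if `t` itself goes out of scope.
        return t->id == 42 ? 0 : 1;
    }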

@@ -10,7 +10,7 @@ namespace gpu {
 struct fuse_ops
 {
-    context * ctx = nullptr;
+    context* ctx = nullptr;
     std::string name() const { return "gpu::fuse_ops"; }
     void apply(program& p) const;
 };

@@ -96,7 +96,8 @@ inline fusion_plan_descriptor make_fusion_plan(const shape& input)
 // Temporary hack to workaround memory problems in miopen
 inline fusion_plan_descriptor make_fusion_plan(const tensor_descriptor& input)
 {
-    return make_obj<fusion_plan_descriptor>(&miopenCreateFusionPlan, miopenVerticalFusion, input.get());
+    return make_obj<fusion_plan_descriptor>(
+        &miopenCreateFusionPlan, miopenVerticalFusion, input.get());
 }
 inline fused_operator_args make_fused_args()
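
make_obj here follows a common wrapper pattern: call a C-style miopenCreateXxx function that returns a handle through an out parameter, check the status, and hand the handle to an RAII wrapper. The standalone sketch below shows that pattern against a made-up C API (create_plan/destroy_plan and the make_obj helper are illustrative assumptions, not the actual migraph implementation):

    // Sketch of the make_obj pattern: C "create" function + out-parameter -> RAII wrapper.
    #include <memory>
    #include <stdexcept>

    // Made-up C-style API standing in for miopenCreateFusionPlan and friends.
    using status_t = int;
    constexpr status_t status_success = 0;
    struct plan_impl { int mode; };
    using plan_t = plan_impl*;

    inline status_t create_plan(plan_t* out, int mode)
    {
        *out = new plan_impl{mode};
        return status_success;
    }
    inline void destroy_plan(plan_t p) { delete p; }

    // RAII wrapper: unique_ptr with a custom deleter calling the C destroy function.
    struct plan_deleter
    {
        void operator()(plan_t p) const { destroy_plan(p); }
    };
    using plan_descriptor = std::unique_ptr<plan_impl, plan_deleter>;

    // Generic helper in the spirit of make_obj<T>(&create_fn, args...).
    template <class Wrapper, class F, class... Ts>
    Wrapper make_obj(F f, Ts... xs)
    {
        typename Wrapper::pointer handle = nullptr;
        if(f(&handle, xs...) != status_success)
            throw std::runtime_error("Creating object failed");
        return Wrapper{handle};
    }

    int main()
    {
        auto plan = make_obj<plan_descriptor>(&create_plan, /*mode=*/1);
        return plan->mode == 1 ? 0 : 1;
    }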