Commit 7c416eb9 authored by Paul's avatar Paul
Browse files

Formatting

parent 7f37c6ab
......@@ -12,14 +12,15 @@ struct check_shapes
const shape* end;
const std::string name;
check_shapes(const shape* b, const shape* e, const std::string& n)
: begin(b), end(e), name(n)
{}
check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
{
}
check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data()+s.size()) {}
check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : begin(s.data()), end(s.data()+s.size()), name(op.name())
check_shapes(const std::vector<shape>& s, const Op& op)
: begin(s.data()), end(s.data() + s.size()), name(op.name())
{
}
......@@ -35,9 +36,9 @@ struct check_shapes
{
assert(begin != nullptr);
assert(end != nullptr);
if(end-begin != n)
if(end - begin != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(end-begin));
" but given " + std::to_string(end - begin));
return *this;
}
......@@ -45,7 +46,7 @@ struct check_shapes
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin!=end)
if(begin != end)
{
if(begin->lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
......@@ -114,7 +115,7 @@ struct check_shapes
{
assert(begin != nullptr);
assert(end != nullptr);
if(begin==end)
if(begin == end)
return true;
auto&& key = f(*begin);
return this->all_of([&](const shape& s) { return f(s) == key; });
......@@ -131,19 +132,13 @@ struct check_shapes
const shape* get(long i)
{
if(i < 0)
return end-i;
return begin+i;
return end - i;
return begin + i;
}
check_shapes slice(long start)
{
return {get(start), end, name};
}
check_shapes slice(long start) { return {get(start), end, name}; }
check_shapes slice(long start, long last)
{
return {get(start), get(last), name};
}
check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
};
} // namespace migraph
......
......@@ -264,7 +264,10 @@ auto any_of(Ts... ms)
}
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins) { return ins->get_shape().broadcasted(); }
MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
}
inline auto name(std::string name)
{
......
......@@ -17,7 +17,7 @@ struct fusion
// Used as a temporary hack to keep descriptor references alive
std::vector<std::shared_ptr<void>> storage;
template<class T>
template <class T>
auto keep_alive(T x)
{
auto result = share(std::move(x));
......@@ -33,7 +33,6 @@ struct fusion
keep_alive(std::move(t));
}
op_t operator[](std::size_t i) const
{
op_t result;
......@@ -43,10 +42,7 @@ struct fusion
return result;
}
auto get() const
{
return fp.get();
}
auto get() const { return fp.get(); }
op_t create_bias(const shape& bias)
{
......@@ -78,7 +74,6 @@ struct fusion
MIGRAPH_THROW("Creating operator failed");
return result;
}
};
struct hip_add_relu
......@@ -121,8 +116,7 @@ struct miopen_conv_bias
fusion::op_t conv;
fusion::op_t bias;
miopen_conv_bias(op::convolution c, shape input, shape weights, shape b)
: op(c), f(input)
miopen_conv_bias(op::convolution c, shape input, shape weights, shape b) : op(c), f(input)
{
f.create_conv(op, weights);
f.create_bias(b);
......@@ -135,7 +129,8 @@ struct miopen_conv_bias
// TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
auto fargs = make_fused_args();
float alpha = 1, beta = 0;
......@@ -143,7 +138,13 @@ struct miopen_conv_bias
auto y = make_tensor(output_shape);
miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
miopenExecuteFusionPlan(ctx.handle.get(), f.get(), x.get(), args[0].implicit(), y.get(), args[4].implicit(), fargs.get());
miopenExecuteFusionPlan(ctx.handle.get(),
f.get(),
x.get(),
args[0].implicit(),
y.get(),
args[4].implicit(),
fargs.get());
return args.at(4);
}
......@@ -163,13 +164,14 @@ struct miopen_conv_bias
struct match_conv_bias
{
context * ctx = nullptr;
context* ctx = nullptr;
auto matcher() const
{
return match::name("gpu::add")(match::any_of(
match::all_of(match::arg(0)(match::broadcast_shape().bind("bias")), match::arg(1)(match::name("gpu::convolution").bind("conv"))),
match::all_of(match::arg(1)(match::broadcast_shape().bind("bias")), match::arg(0)(match::name("gpu::convolution").bind("conv")))
));
match::all_of(match::arg(0)(match::broadcast_shape().bind("bias")),
match::arg(1)(match::name("gpu::convolution").bind("conv"))),
match::all_of(match::arg(1)(match::broadcast_shape().bind("bias")),
match::arg(0)(match::name("gpu::convolution").bind("conv")))));
}
void apply(program& p, match::matcher_result r) const
......@@ -183,7 +185,8 @@ struct match_conv_bias
auto alloc_ins = ins->inputs().back();
auto old_ws_ins = conv_ins->inputs().at(2);
miopen_conv_bias cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
miopen_conv_bias cb{
conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
// TODO: Insert ws allocation
auto ws = cb.compile(*ctx);
......@@ -191,7 +194,8 @@ struct match_conv_bias
}
};
void fuse_ops::apply(program& p) const {
void fuse_ops::apply(program& p) const
{
match::find_matches(p, match_add_relu{}, match_conv_bias{ctx});
}
......
......@@ -10,7 +10,7 @@ namespace gpu {
struct fuse_ops
{
context * ctx = nullptr;
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_ops"; }
void apply(program& p) const;
};
......
......@@ -96,7 +96,8 @@ inline fusion_plan_descriptor make_fusion_plan(const shape& input)
// Temporary hack to workaround memory problems in miopen
inline fusion_plan_descriptor make_fusion_plan(const tensor_descriptor& input)
{
return make_obj<fusion_plan_descriptor>(&miopenCreateFusionPlan, miopenVerticalFusion, input.get());
return make_obj<fusion_plan_descriptor>(
&miopenCreateFusionPlan, miopenVerticalFusion, input.get());
}
inline fused_operator_args make_fused_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment