Commit 7489c75d authored by Paul's avatar Paul
Browse files

Formatting

parent cb9bfaf4
......@@ -7,13 +7,13 @@ namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto a, auto b) { return std::max<decltype(a*x + b)>(0, a*x + b); });
[](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
}
void add_relu(hipStream_t stream,
......
......@@ -230,7 +230,8 @@ struct hip_mul_add_relu
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul_add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
device::mul_add_relu(
ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
......@@ -354,8 +355,8 @@ struct find_mul_add_relu
void apply(program& p, match::matcher_result r) const
{
auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result;
auto args = mul_add_ins->inputs();
auto ins = r.result;
auto args = mul_add_ins->inputs();
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
......
......@@ -12,10 +12,10 @@ namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_relu(hipStream_t stream,
const argument& result,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment