Commit 7489c75d authored by Paul's avatar Paul
Browse files

Formatting

parent cb9bfaf4
...@@ -7,13 +7,13 @@ namespace gpu { ...@@ -7,13 +7,13 @@ namespace gpu {
namespace device { namespace device {
void mul_add_relu(hipStream_t stream, void mul_add_relu(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3) const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)( nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto a, auto b) { return std::max<decltype(a*x + b)>(0, a*x + b); }); [](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
} }
void add_relu(hipStream_t stream, void add_relu(hipStream_t stream,
......
...@@ -230,7 +230,8 @@ struct hip_mul_add_relu ...@@ -230,7 +230,8 @@ struct hip_mul_add_relu
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::mul_add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2)); device::mul_add_relu(
ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3); return args.at(3);
} }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
...@@ -354,8 +355,8 @@ struct find_mul_add_relu ...@@ -354,8 +355,8 @@ struct find_mul_add_relu
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
{ {
auto mul_add_ins = r.instructions["mul_add"]; auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result; auto ins = r.result;
auto args = mul_add_ins->inputs(); auto args = mul_add_ins->inputs();
// Use the allocation from the relu operator // Use the allocation from the relu operator
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
......
...@@ -12,10 +12,10 @@ namespace gpu { ...@@ -12,10 +12,10 @@ namespace gpu {
namespace device { namespace device {
void mul_add_relu(hipStream_t stream, void mul_add_relu(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3); const argument& arg3);
void add_relu(hipStream_t stream, void add_relu(hipStream_t stream,
const argument& result, const argument& result,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment