Commit b3e2901e authored by Paul's avatar Paul
Browse files

Formatting

parent 9979fada
...@@ -14,8 +14,9 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const ...@@ -14,8 +14,9 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
argument argument hip_concat::compute(context& ctx,
hip_concat::compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const const shape& output_shape,
const std::vector<argument>& args) const
{ {
std::vector<std::size_t> offsets = op.compute_offsets(output_shape, args); std::vector<std::size_t> offsets = op.compute_offsets(output_shape, args);
return device::concat(ctx.get_stream().get(), output_shape, args, offsets); return device::concat(ctx.get_stream().get(), output_shape, args, offsets);
......
...@@ -12,8 +12,9 @@ shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const ...@@ -12,8 +12,9 @@ shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)}); return op.compute_shape({inputs.at(0)});
} }
argument argument miopen_contiguous::compute(context& ctx,
miopen_contiguous::compute(context& ctx, shape output_shape, const std::vector<argument>& args) const shape output_shape,
const std::vector<argument>& args) const
{ {
assert(output_shape == args[1].get_shape()); assert(output_shape == args[1].get_shape());
assert(output_shape.standard()); assert(output_shape.standard());
......
...@@ -10,7 +10,11 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const ...@@ -10,7 +10,11 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x + y; }); nary(stream, result, arg1, arg2)([](auto x, auto y) { return x + y; });
} }
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3) void add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; }); nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; });
} }
......
...@@ -5,12 +5,17 @@ namespace migraph { ...@@ -5,12 +5,17 @@ namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
void add_relu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); }); nary(stream, result, arg1, arg2)(
[](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
} }
void add_relu(hipStream_t stream, const argument& result, void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3) const argument& arg3)
......
...@@ -8,7 +8,8 @@ namespace migraph { ...@@ -8,7 +8,8 @@ namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument concat(hipStream_t stream, const migraph::shape& output_shape, argument concat(hipStream_t stream,
const migraph::shape& output_shape,
std::vector<migraph::argument> args, std::vector<migraph::argument> args,
std::vector<std::size_t> offsets) std::vector<std::size_t> offsets)
{ {
......
...@@ -52,8 +52,12 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments.. ...@@ -52,8 +52,12 @@ auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments..
} }
template <class F> template <class F>
void trinary_broadcast_vec_impl( void trinary_broadcast_vec_impl(hipStream_t stream,
hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3) F f,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape(); const auto& b_shape = arg3.get_shape();
...@@ -107,8 +111,12 @@ void trinary_broadcast_vec_impl( ...@@ -107,8 +111,12 @@ void trinary_broadcast_vec_impl(
} }
template <class F> template <class F>
void trinary_broadcast_impl( void trinary_broadcast_impl(hipStream_t stream,
hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3) F f,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape(); const auto& b_shape = arg3.get_shape();
...@@ -154,10 +162,8 @@ void trinary_broadcast_impl( ...@@ -154,10 +162,8 @@ void trinary_broadcast_impl(
} }
template <class F> template <class F>
void binary_broadcast_vec_impl(hipStream_t stream, F f, void binary_broadcast_vec_impl(
const argument& result, hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
const argument& arg1,
const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
...@@ -209,7 +215,8 @@ void binary_broadcast_vec_impl(hipStream_t stream, F f, ...@@ -209,7 +215,8 @@ void binary_broadcast_vec_impl(hipStream_t stream, F f,
} }
template <class F> template <class F>
void binary_broadcast_impl(hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2) void binary_broadcast_impl(
hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
...@@ -321,7 +328,8 @@ auto nary(hipStream_t stream, argument result, Arguments... args) ...@@ -321,7 +328,8 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
return [=](auto f) { nary_impl(stream, f, result, args...); }; return [=](auto f) { nary_impl(stream, f, result, args...); };
} }
inline auto nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) inline auto
nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
return [=](auto f) { return [=](auto f) {
// TODO: Check result and arg1 shape is the same // TODO: Check result and arg1 shape is the same
...@@ -349,8 +357,11 @@ inline auto nary(hipStream_t stream, const argument& result, const argument& arg ...@@ -349,8 +357,11 @@ inline auto nary(hipStream_t stream, const argument& result, const argument& arg
}; };
} }
inline auto inline auto nary(hipStream_t stream,
nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3) const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
return [=](auto f) { return [=](auto f) {
// TODO: Check result and arg1 shape is the same // TODO: Check result and arg1 shape is the same
......
...@@ -11,7 +11,11 @@ namespace device { ...@@ -11,7 +11,11 @@ namespace device {
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2); void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3); void add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -9,9 +9,13 @@ namespace migraph { ...@@ -9,9 +9,13 @@ namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
void add_relu(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2); void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_relu(hipStream_t stream, const argument& result, void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2, const argument& arg2,
const argument& arg3); const argument& arg3);
......
...@@ -8,8 +8,10 @@ namespace migraph { ...@@ -8,8 +8,10 @@ namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument argument concat(hipStream_t stream,
concat(hipStream_t stream, const shape& output_shape, std::vector<argument> args, std::vector<std::size_t> offsets); const shape& output_shape,
std::vector<argument> args,
std::vector<std::size_t> offsets);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment