Commit 905dbf49 authored by Paul's avatar Paul
Browse files

Add an add operator

parent b0d0c6aa
...@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen) ...@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen)
endif() endif()
add_library(migraph_device add_library(migraph_device
device/add.cpp
device/add_relu.cpp device/add_relu.cpp
device/contiguous.cpp device/contiguous.cpp
) )
......
#include <migraph/gpu/device/add.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace migraph {
namespace gpu {
namespace device {
void add(argument result, argument arg1, argument arg2)
{
nary(std::move(result), std::move(arg1), std::move(arg2))(
[](auto x, auto y) { return x + y; });
}
} // namespace device
} // namespace gpu
} // namespace migraph
...@@ -7,7 +7,7 @@ namespace device { ...@@ -7,7 +7,7 @@ namespace device {
void add_relu(argument result, argument arg1, argument arg2) void add_relu(argument result, argument arg1, argument arg2)
{ {
nary_standard(std::move(result), std::move(arg1), std::move(arg2))( nary(std::move(result), std::move(arg1), std::move(arg2))(
[](auto x, auto y) { return max(0, x + y); }); [](auto x, auto y) { return max(0, x + y); });
} }
......
...@@ -10,18 +10,6 @@ namespace migraph { ...@@ -10,18 +10,6 @@ namespace migraph {
namespace gpu { namespace gpu {
namespace device { namespace device {
template <class... Arguments>
auto nary(argument result, Arguments... args)
{
return [=](auto f) {
if(all_of({args...}, [](const shape& s) { return s.standard(); }))
nary_standard(result, args...)(f);
else
nary_nonstandard(result, args...)(f);
};
}
template <class F, class... Arguments> template <class F, class... Arguments>
auto nary_nonstandard_impl(F f, argument result, Arguments... args) auto nary_nonstandard_impl(F f, argument result, Arguments... args)
{ {
...@@ -65,6 +53,18 @@ auto nary_standard(argument result, Arguments... args) ...@@ -65,6 +53,18 @@ auto nary_standard(argument result, Arguments... args)
}; };
} }
template <class... Arguments>
auto nary(argument result, Arguments... args)
{
return [=](auto f) {
if(all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }))
nary_standard(result, args...)(f);
else
nary_nonstandard(result, args...)(f);
};
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
......
#ifndef MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_HPP
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_HPP
#include <migraph/argument.hpp>
namespace migraph {
namespace gpu {
namespace device {
void add(argument result, argument arg1, argument arg2);
} // namespace device
} // namespace gpu
} // namespace migraph
#endif
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp> #include <migraph/dfor.hpp>
#include <migraph/gpu/device/contiguous.hpp> #include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/add.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp> #include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp> #include <migraph/gpu/context.hpp>
...@@ -168,33 +169,19 @@ struct miopen_pooling ...@@ -168,33 +169,19 @@ struct miopen_pooling
} }
}; };
struct miopen_add struct hip_add
{ {
std::string name() const { return "gpu::add"; } std::string name() const { return "gpu::add"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3).not_broadcasted(); check_shapes{inputs, *this}.has(3);
return inputs.at(0); return inputs.at(0);
} }
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context&, const shape&, const std::vector<argument>& args) const
{ {
float alpha = 1, beta = 0; device::add(args[2], args[0], args[1]);
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(ctx.handle.get(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
args[0].implicit(),
&alpha,
b_desc.get(),
args[1].implicit(),
&beta,
c_desc.get(),
args[2].implicit());
return args[2]; return args[2];
} }
}; };
...@@ -390,7 +377,7 @@ struct miopen_apply ...@@ -390,7 +377,7 @@ struct miopen_apply
{ {
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
return prog->replace_instruction( return prog->replace_instruction(
ins, miopen_add{}, ins->arguments.at(0), ins->arguments.at(1), output); ins, hip_add{}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
instruction_ref apply_gemm(instruction_ref ins) instruction_ref apply_gemm(instruction_ref ins)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment