"examples/vscode:/vscode.git/clone" did not exist on "bc3c64aa786dcfc0e2a68c26bed5ba66b74488af"
Commit 08c3a87f authored by Paul's avatar Paul
Browse files

Update unary operator

parent 79b2b1fc
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/make_signed.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -17,9 +18,12 @@ namespace migraphx { ...@@ -17,9 +18,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct abs : unary struct abs : unary<abs>
{ {
std::string name() const { return "abs"; } auto apply() const
{
return [](auto x) { return std::abs(make_signed(x)); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct acos : unary struct acos : unary<acos>
{ {
std::string name() const { return "acos"; } auto apply() const
{
return [](auto x) { return std::acos(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct asin : unary struct asin : unary<asin>
{ {
std::string name() const { return "asin"; } auto apply() const
{
return [](auto x) { return std::asin(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct atan : unary struct atan : unary<atan>
{ {
std::string name() const { return "atan"; } auto apply() const
{
return [](auto x) { return std::atan(x); };
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array> #include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
template <class Derived> template <class Derived>
struct binary struct binary : op_name<Derived>
{ {
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::") + 2);
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct cos : unary struct cos : unary<cos>
{ {
std::string name() const { return "cos"; } auto apply() const
{
return [](auto x) { return std::cos(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct cosh : unary struct cosh : unary<cosh>
{ {
std::string name() const { return "cosh"; } auto apply() const
{
return [](auto x) { return std::cosh(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct exp : unary struct exp : unary<exp>
{ {
std::string name() const { return "exp"; } auto apply() const
{
return [](auto x) { return std::exp(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct log : unary struct log : unary<log>
{ {
std::string name() const { return "log"; } auto apply() const
{
return [](auto x) { return std::log(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct neg : unary struct neg : unary<neg>
{ {
std::string name() const { return "neg"; } auto apply() const
{
return [](auto x) { return -x; };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct relu : unary struct relu : unary<relu>
{ {
std::string name() const { return "relu"; } auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sigmoid : unary struct sigmoid : unary<sigmoid>
{ {
std::string name() const { return "sigmoid"; } auto apply() const
{
return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sin : unary struct sin : unary<sin>
{ {
std::string name() const { return "sin"; } auto apply() const
{
return [](auto x) { return std::sin(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sinh : unary struct sinh : unary<sinh>
{ {
std::string name() const { return "sinh"; } auto apply() const
{
return [](auto x) { return std::sinh(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct tan : unary struct tan : unary<tan>
{ {
std::string name() const { return "tan"; } auto apply() const
{
return [](auto x) { return std::tan(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct tanh : unary struct tanh : unary<tanh>
{ {
std::string name() const { return "tanh"; } auto apply() const
{
return [](auto x) { return std::tanh(x); };
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP #define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array> #include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct unary template <class Derived>
struct unary : op_name<Derived>
{ {
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
if(input.get_shape().standard())
{
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
}
});
return result;
}
}; };
} // namespace op } // namespace op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment