Commit 08c3a87f authored by Paul's avatar Paul
Browse files

Update unary operator

parent 79b2b1fc
......@@ -10,6 +10,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/make_signed.hpp>
#include <cmath>
#include <utility>
......@@ -17,9 +18,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct abs : unary
struct abs : unary<abs>
{
std::string name() const { return "abs"; }
auto apply() const
{
return [](auto x) { return std::abs(make_signed(x)); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct acos : unary
struct acos : unary<acos>
{
std::string name() const { return "acos"; }
auto apply() const
{
return [](auto x) { return std::acos(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct asin : unary
struct asin : unary<asin>
{
std::string name() const { return "asin"; }
auto apply() const
{
return [](auto x) { return std::asin(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct atan : unary
struct atan : unary<atan>
{
std::string name() const { return "atan"; }
auto apply() const
{
return [](auto x) { return std::atan(x); };
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/name.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct binary
struct binary : op_name<Derived>
{
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::") + 2);
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct cos : unary
struct cos : unary<cos>
{
std::string name() const { return "cos"; }
auto apply() const
{
return [](auto x) { return std::cos(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct cosh : unary
struct cosh : unary<cosh>
{
std::string name() const { return "cosh"; }
auto apply() const
{
return [](auto x) { return std::cosh(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct exp : unary
struct exp : unary<exp>
{
std::string name() const { return "exp"; }
auto apply() const
{
return [](auto x) { return std::exp(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct log : unary
struct log : unary<log>
{
std::string name() const { return "log"; }
auto apply() const
{
return [](auto x) { return std::log(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct neg : unary
struct neg : unary<neg>
{
std::string name() const { return "neg"; }
auto apply() const
{
return [](auto x) { return -x; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct relu : unary
struct relu : unary<relu>
{
std::string name() const { return "relu"; }
auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sigmoid : unary
struct sigmoid : unary<sigmoid>
{
std::string name() const { return "sigmoid"; }
auto apply() const
{
return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sin : unary
struct sin : unary<sin>
{
std::string name() const { return "sin"; }
auto apply() const
{
return [](auto x) { return std::sin(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sinh : unary
struct sinh : unary<sinh>
{
std::string name() const { return "sinh"; }
auto apply() const
{
return [](auto x) { return std::sinh(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct tan : unary
struct tan : unary<tan>
{
std::string name() const { return "tan"; }
auto apply() const
{
return [](auto x) { return std::tan(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct tanh : unary
struct tanh : unary<tanh>
{
std::string name() const { return "tanh"; }
auto apply() const
{
return [](auto x) { return std::tanh(x); };
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/name.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct unary
template <class Derived>
struct unary : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
if(input.get_shape().standard())
{
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
}
});
return result;
}
};
} // namespace op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment