Commit 65e14286 authored by charlie's avatar charlie
Browse files

Unary ops changes and tests

parent 30243d2c
......@@ -32,21 +32,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct elu
struct elu : unary<elu>
{
std::string name() const { return "elu"; }
float alpha = 1;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"));
}
auto apply() const
{
return [&](auto x) { return x > 0 ? x : alpha * std::expm1(x); };
}
};
} // namespace op
......
......@@ -26,12 +26,13 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct leaky_relu
struct leaky_relu : unary<leaky_relu>
{
float alpha = 0.01;
......@@ -42,10 +43,10 @@ struct leaky_relu
}
std::string name() const { return "leaky_relu"; }
shape compute_shape(std::vector<shape> inputs) const
auto apply() const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
return [&](auto x) { return x > 0 ? x : x * alpha; };
}
};
......
......@@ -507,65 +507,6 @@ struct ref_quant_gemm
};
MIGRAPHX_REGISTER_OP(ref_gemm)
struct leaky_relu_op
{
op::leaky_relu op;
std::string name() const { return "ref::leaky_relu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : x * a; };
}
};
struct elu_op
{
op::elu op;
std::string name() const { return "ref::elu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
}
};
template <typename Op>
struct ref_unary : auto_register_op<ref_unary<Op>>
{
ref_unary() = default;
template <class T>
ref_unary(T pop) : op(Op{std::move(pop)})
{
}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op.op, f);
}
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
assert(input.get_shape().standard());
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
return result;
}
};
template <class Op>
struct ref_softmax : auto_register_op<ref_softmax<Op>>
{
......@@ -708,9 +649,7 @@ struct ref_apply
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["elu"] = extend_op<ref_unary<elu_op>, op::elu>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["leaky_relu"] = extend_op<ref_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
apply_map["pad"] = extend_op<ref_pad, op::pad>();
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment