"examples/vscode:/vscode.git/clone" did not exist on "2411fa28f053beea411f7fd595f181065008291f"
Commit 65e14286 authored by charlie's avatar charlie
Browse files

Unary ops changes and tests

parent 30243d2c
...@@ -32,21 +32,20 @@ namespace migraphx { ...@@ -32,21 +32,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct elu struct elu : unary<elu>
{ {
std::string name() const { return "elu"; }
float alpha = 1; float alpha = 1;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.alpha, "alpha")); return pack(f(self.alpha, "alpha"));
} }
auto apply() const
{
return [&](auto x) { return x > 0 ? x : alpha * std::expm1(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -26,12 +26,13 @@ ...@@ -26,12 +26,13 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/unary.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct leaky_relu struct leaky_relu : unary<leaky_relu>
{ {
float alpha = 0.01; float alpha = 0.01;
...@@ -42,10 +43,10 @@ struct leaky_relu ...@@ -42,10 +43,10 @@ struct leaky_relu
} }
std::string name() const { return "leaky_relu"; } std::string name() const { return "leaky_relu"; }
shape compute_shape(std::vector<shape> inputs) const
auto apply() const
{ {
check_shapes{inputs, *this}.has(1); return [&](auto x) { return x > 0 ? x : x * alpha; };
return inputs.front();
} }
}; };
......
...@@ -507,65 +507,6 @@ struct ref_quant_gemm ...@@ -507,65 +507,6 @@ struct ref_quant_gemm
}; };
MIGRAPHX_REGISTER_OP(ref_gemm) MIGRAPHX_REGISTER_OP(ref_gemm)
struct leaky_relu_op
{
op::leaky_relu op;
std::string name() const { return "ref::leaky_relu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : x * a; };
}
};
struct elu_op
{
op::elu op;
std::string name() const { return "ref::elu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
}
};
template <typename Op>
struct ref_unary : auto_register_op<ref_unary<Op>>
{
ref_unary() = default;
template <class T>
ref_unary(T pop) : op(Op{std::move(pop)})
{
}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op.op, f);
}
std::string name() const { return op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
assert(input.get_shape().standard());
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
return result;
}
};
template <class Op> template <class Op>
struct ref_softmax : auto_register_op<ref_softmax<Op>> struct ref_softmax : auto_register_op<ref_softmax<Op>>
{ {
...@@ -708,9 +649,7 @@ struct ref_apply ...@@ -708,9 +649,7 @@ struct ref_apply
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>(); apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] = apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>(); extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["elu"] = extend_op<ref_unary<elu_op>, op::elu>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>(); apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["leaky_relu"] = extend_op<ref_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>(); apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>(); apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
apply_map["pad"] = extend_op<ref_pad, op::pad>(); apply_map["pad"] = extend_op<ref_pad, op::pad>();
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment