Commit 48c7c810 authored by charlie
Browse files

Fix elu and leaky_relu pointwise JIT

parent 5ba8cdf6
......@@ -36,6 +36,11 @@ struct elu : unary<elu>
{
float alpha = 1;
std::string point_op() const
{
    // Pointwise JIT expression for ELU:
    //   x                     when x > 0
    //   alpha * (exp(x) - 1)  otherwise
    // ${...} placeholders are substituted by the pointwise code generator.
    const std::string expr =
        "${function:where}(${0} > 0, ${0}, ${alpha} * (migraphx::exp(${0}) - 1))";
    return expr;
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
......
......@@ -42,6 +42,8 @@ struct leaky_relu : unary<leaky_relu>
return pack(f(self.alpha, "alpha"));
}
std::string point_op() const
{
    // Pointwise JIT expression for leaky ReLU: x when x > 0, alpha * x otherwise.
    // ${...} placeholders are substituted by the pointwise code generator.
    const std::string expr = "${function:where}(${0} > 0, ${0}, ${alpha} * ${0})";
    return expr;
}
// Operator identifier used for registration and make_op lookup.
std::string name() const { return "leaky_relu"; }
auto apply() const
......
......@@ -216,55 +216,6 @@ struct cpu_pad
};
MIGRAPHX_REGISTER_OP(cpu_pad)
struct leaky_relu_op
{
op::leaky_relu op;
std::string name() const { return "cpu::leaky_relu"; }
auto fcn() const
{
auto a = op.alpha;
return [a](auto x) { return x > 0 ? x : x * a; };
}
};
// Generic CPU elementwise-unary wrapper: adapts a kernel descriptor
// (e.g. leaky_relu_op) into a registered operator that applies the
// descriptor's functor over a single input tensor.
template <typename UnaryOp>
struct cpu_unary2 : auto_register_op<cpu_unary2<UnaryOp>>
{
    cpu_unary2() = default;

    // Construct from the wrapped framework operator (forwarded into UnaryOp).
    template <class T>
    cpu_unary2(T pop) : op(UnaryOp{std::move(pop)})
    {
    }

    UnaryOp op;

    // Expose the inner framework operator's fields (e.g. alpha) for
    // serialization/reflection.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op.op, f);
    }

    std::string name() const { return op.name(); }

    // Elementwise op: output has the same type and dimensions as the
    // single input.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& in = inputs.front();
        return {in.type(), in.lens()};
    }

    // Apply the functor element-by-element over the input buffer.
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto f = op.fcn();
        visit_all(result, args[0])([&](auto out, auto in) {
            // Linear std::transform assumes a packed, standard layout.
            assert(in.get_shape().standard());
            std::transform(in.begin(), in.end(), out.begin(), f);
        });
        return result;
    }
};
template struct cpu_unary2<leaky_relu_op>;
struct cpu_rnn_var_sl_last_output
{
op::rnn_var_sl_last_output op;
......
......@@ -96,9 +96,9 @@ struct miopen_apply
add_extend_op("argmax");
add_extend_op("argmin");
add_extend_op("elu");
// add_extend_op("elu");
add_extend_op("gather");
add_extend_op("leaky_relu");
// add_extend_op("leaky_relu");
add_extend_op("logsoftmax");
add_extend_op("lrn");
add_extend_op("multinomial");
......
......@@ -34,7 +34,7 @@ struct test_elu : verify_program<test_elu>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("elu", {{"alpha", 1.0}}), x);
mm->add_instruction(migraphx::make_op("elu", {{"alpha", 0.8}}), x);
return p;
}
};
......@@ -34,7 +34,7 @@ struct test_leaky_relu : verify_program<test_leaky_relu>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 0.01}}), x);
mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 0.41}}), x);
return p;
}
};
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.