Commit a7e678ca authored by Scott Thornton's avatar Scott Thornton
Browse files

Added bodies of unary and binary operators

parent 231b7edd
......@@ -86,23 +86,145 @@ struct cpu_gemm
}
};
struct relu
struct identity_op
{
std::string name() const { return "cpu::relu"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
std::string name() const {return "cpu::identity"; }
auto fcn() { return [](auto x) { return x; }; }
};
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [](auto x) {
return x > 0 ? x : 0;
});
});
});
return result;
}
struct abs_op
{
std::string name() const {return "cpu::abs"; }
auto fcn() { return [](auto x) { return std::abs(x); }; }
};
struct exp_op
{
std::string name() const {return "cpu::exp"; }
auto fcn() { return [](auto x) { return std::exp(x); }; }
};
struct sin_op
{
std::string name() const {return "cpu::sin"; }
auto fcn() { return [](auto x) { return std::sin(x); }; }
};
struct cos_op
{
std::string name() const {return "cpu::cos"; }
auto fcn() { return [](auto x) { return std::cos(x); }; }
};
struct tan_op
{
std::string name() const {return "cpu::tan"; }
auto fcn() { return [](auto x) { return std::tan(x); }; }
};
struct asin_op
{
std::string name() const {return "cpu::asin"; }
auto fcn() { return [](auto x) { return std::asin(x); }; }
};
struct acos_op
{
std::string name() const {return "cpu::acos"; }
auto fcn() { return [](auto x) { return std::acos(x); }; }
};
struct atan_op
{
std::string name() const {return "cpu::atan"; }
auto fcn() { return [](auto x) { return std::atan(x); }; }
};
struct softmax_op
{
std::string name() const {return "cpu::softmax"; }
};
struct tanh_op
{
std::string name() const {return "cpu::tanh"; }
auto fcn() { return [](auto x) { return std::tanh(x); }; }
};
struct sigmoid_op
{
std::string name() const {return "cpu::sigmoid"; }
auto fcn() { return [](auto x) { return 1.f/(1.f + std::exp(-x)); }; }
};
struct neg_op
{
std::string name() const {return "cpu::neg"; }
auto fcn() { return [](auto x) { return -x; }; }
};
struct relu_op
{
std::string name() const {return "cpu::relu"; }
auto fcn() const { return [](auto x) { return x > 0 ? x : 0; }; }
};
template <typename Op>
struct cpu_unary
{
Op op;
std::string name() const { op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), op.fcn());
});
});
return result;
}
};
struct add_op
{
std::string name() const { return "add"; }
auto fcn() const { return [](auto x, auto y) {return x + y;};}
};
struct sub_op
{
std::string name() const { return "sub"; }
auto fcn() const { return [](auto x, auto y) {return x - y;};}
};
struct mul_op
{
std::string name() const { return "mul"; }
auto fcn() const { return [](auto x, auto y) {return x * y;};}
};
struct div_op
{
std::string name() const { return "div"; }
auto fcn() const { return [](auto x, auto y) {return x / y;};}
};
template <typename Op>
struct cpu_binary
{
Op op;
std::string name() const { op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(), op.fcn());
});
return result;
}
};
struct cpu_apply
......@@ -134,7 +256,7 @@ struct cpu_apply
{
auto&& op = any_cast<activation>(ins->op);
if(op.mode == "relu")
prog->replace_instruction(ins, relu{}, ins->arguments);
prog->replace_instruction(ins, cpu_unary<relu_op>{}, ins->arguments);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment