Commit d70800c5 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added softmax operator cpu implementation

parent a7e678ca
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
namespace rtg { namespace rtg {
namespace cpu { namespace cpu {
template <typename T>
T zero(const T& x) { return T(0); }
struct cpu_convolution struct cpu_convolution
{ {
convolution op; convolution op;
...@@ -83,6 +86,7 @@ struct cpu_gemm ...@@ -83,6 +86,7 @@ struct cpu_gemm
} }
} }
}); });
return C;
} }
}; };
...@@ -140,11 +144,6 @@ struct atan_op ...@@ -140,11 +144,6 @@ struct atan_op
auto fcn() { return [](auto x) { return std::atan(x); }; } auto fcn() { return [](auto x) { return std::atan(x); }; }
}; };
struct softmax_op
{
std::string name() const {return "cpu::softmax"; }
};
struct tanh_op struct tanh_op
{ {
std::string name() const {return "cpu::tanh"; } std::string name() const {return "cpu::tanh"; }
...@@ -173,7 +172,7 @@ template <typename Op> ...@@ -173,7 +172,7 @@ template <typename Op>
struct cpu_unary struct cpu_unary
{ {
Op op; Op op;
std::string name() const { op.name(); } std::string name() const { return op.name(); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
...@@ -187,6 +186,26 @@ struct cpu_unary ...@@ -187,6 +186,26 @@ struct cpu_unary
} }
}; };
struct softmax
{
std::string name() const { return "cpu::softmax"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(),
[](auto x) { return std::exp(x); });
float t = std::accumulate(output.begin(), output.end(), zero(input.front()));
std::transform(output.begin(), output.end(), output.begin(),
[t](auto x) { return x/t; });
});
});
return result;
}
};
struct add_op struct add_op
{ {
std::string name() const { return "add"; } std::string name() const { return "add"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment