Commit 81b0631e authored by Scott Thornton's avatar Scott Thornton
Browse files

Added softmax2d cpu operator and test

parent dd653b52
......@@ -351,7 +351,7 @@ struct softmax
// Operator identifier used for registration/lookup and debug printing.
std::string name() const {return "softmax"; }
// Shape inference for softmax: the operator normalizes values in place,
// so the output shape is identical to the (single) input shape.
// The input is restricted to exactly one 4-D tensor — presumably
// NCHW with softmax taken over the channel axis, matching the CPU
// softmax2d implementation (TODO confirm layout against callers).
shape compute_shape(std::vector<shape> inputs) const
{
    // NOTE(review): the diff scrape left the pre-change `has(1);` line in
    // place alongside the new one; only the chained form is kept here.
    check_shapes{inputs}.has(1).only_dims(4);
    return inputs.at(0);
}
// Reference op only: evaluation is provided by a target-specific lowering
// (see cpu::softmax2d), so direct computation is deliberately an error.
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......
......@@ -11,6 +11,7 @@ namespace rtg {
template <class T>
struct tensor_view
{
using value_type = T;
tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(s) {}
......
......@@ -61,8 +61,8 @@ struct cpu_gemm
argument compute(shape output_shape, std::vector<argument> args) const
{
argument C{output_shape};
visit_all(C, args[0], args[1])([&](auto C, auto A, auto B) {
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto C, auto A, auto B) {
auto M = A.get_shape().lens()[0];
auto N = B.get_shape().lens()[1];
auto K = B.get_shape().lens()[0];
......@@ -86,7 +86,7 @@ struct cpu_gemm
}
}
});
return C;
return result;
}
};
......@@ -186,21 +186,43 @@ struct cpu_unary
}
};
// CPU implementation of softmax over the channel axis of a 4-D tensor.
// For every (batch, row, col) position, the values across dimension 1 are
// exponentiated and normalized so they sum to 1.  The per-position channel
// maximum is subtracted before exponentiation for numerical stability
// (exp of large inputs would otherwise overflow).
struct softmax2d
{
    std::string name() const { return "cpu::softmax2d"; }

    // Softmax preserves the input shape; validation (rank 4, single input)
    // is performed by the reference op before lowering reaches this one.
    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }

    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
            using value_type = typename decltype(input)::value_type;
            // lens() = {batch, channels, height, width} for the 4-D input.
            const auto batches  = input.get_shape().lens()[0];
            const auto channels = input.get_shape().lens()[1];
            const auto height   = input.get_shape().lens()[2];
            const auto width    = input.get_shape().lens()[3];
            for(std::size_t b = 0; b < batches; b++)
            {
                for(std::size_t i = 0; i < height; i++)
                {
                    for(std::size_t j = 0; j < width; j++)
                    {
                        // Stability shift: max over the channel fiber.
                        value_type cmax = std::numeric_limits<value_type>::lowest();
                        for(std::size_t c = 0; c < channels; c++)
                            cmax = std::max(cmax, input(b, c, i, j));
                        // exp(x - cmax) per channel, accumulating the
                        // normalizer in the same pass (same summation order
                        // as a separate loop, so results are identical).
                        value_type sum = value_type(0);
                        for(std::size_t c = 0; c < channels; c++)
                        {
                            output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
                            sum += output(b, c, i, j);
                        }
                        // Normalize so the channel fiber sums to 1.
                        for(std::size_t c = 0; c < channels; c++)
                            output(b, c, i, j) = output(b, c, i, j) / sum;
                    }
                }
            }
        });
        return result;
    }
};
......@@ -333,7 +355,7 @@ struct cpu_apply
// Lower a reference softmax instruction to the CPU softmax2d operator.
void apply_softmax(instruction_ref ins)
{
    // any_cast validates the instruction really carries a softmax op
    // (it throws on a type mismatch); the op object itself holds no
    // state we need, so the result is intentionally discarded.
    any_cast<softmax>(ins->op);
    prog->replace_instruction(ins, softmax2d{}, ins->arguments);
}
void apply_tanh(instruction_ref ins)
......
......@@ -54,8 +54,90 @@ void gemm_test() {
}
}
// End-to-end test of the CPU softmax lowering: builds a program with a
// single softmax over a 5x3x4x2 literal, runs it on the cpu target, and
// compares every output element against precomputed expected values.
// A is the raw input tensor; S is the expected channel-axis softmax of A
// (presumably generated offline, e.g. with NumPy — values carry ~8
// significant digits, so the 1e-6 absolute tolerance is appropriate
// for float arithmetic).
void softmax_test() {
    rtg::program p;
    std::vector<float> A = {-5.61869681e-01, 9.07827199e-01, 1.29255986e+00,
                            3.18533443e-02, -1.22183852e-03, -2.83830553e-01,
                            -1.03245842e+00, -9.28322077e-01, -8.82696748e-01,
                            1.11327164e-01, -9.20038462e-01, 8.47388089e-01,
                            2.51734018e-01, 1.50563884e+00, 2.23056650e+00,
                            -6.17576987e-02, -1.00264274e-01, -6.10369384e-01,
                            1.17537189e+00, -2.51560897e-01, -8.50333512e-01,
                            -8.03578615e-01, -6.51194930e-01, -2.58137047e-01,
                            4.65528190e-01, 3.23284641e-02, -1.54700470e+00,
                            1.38096774e+00, 5.39869189e-01, -7.56884992e-01,
                            1.81503093e+00, -2.11269641e+00, 1.92466557e+00,
                            1.77230799e+00, 2.21660900e+00, 1.56777036e+00,
                            -2.08995026e-03, 3.50566894e-01, -1.15042710e+00,
                            -1.18577778e+00, 8.90633047e-01, -6.63949102e-02,
                            1.44661188e+00, 1.59215283e+00, -2.56262213e-01,
                            9.39079225e-01, 4.07298543e-02, 3.86590779e-01,
                            6.09607756e-01, 8.22331488e-01, -2.82126725e-01,
                            -9.49052632e-01, -4.24012303e-01, -5.32990396e-01,
                            -3.18386006e+00, 3.27092171e-01, -1.33315325e+00,
                            3.62459183e-01, 3.74710828e-01, -1.30302286e+00,
                            1.79680198e-01, -4.51832324e-01, 4.34282750e-01,
                            -7.09520102e-01, 6.20333970e-01, -1.28712380e+00,
                            2.04130828e-01, -7.70607769e-01, 1.61889160e+00,
                            -1.50951004e+00, -4.10505563e-01, -3.56566496e-02,
                            -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
                            -8.28408226e-02, 2.73412596e-02, 5.79780899e-03,
                            9.87900198e-02, -7.95276761e-01, -1.38536084e+00,
                            -6.63573861e-01, 3.89783204e-01, -1.30670881e+00,
                            -7.62425125e-01, -4.04883057e-01, 6.24344349e-01,
                            3.68128955e-01, -1.01577950e+00, -3.06715906e-01,
                            5.67961395e-01, 2.98198581e-01, -1.63613629e+00,
                            -3.75131965e-01, -6.75393403e-01, 2.59172034e+00,
                            6.75538957e-01, 9.07939598e-02, 1.92257717e-01,
                            -1.21592450e+00, -2.73682117e-01, 1.25232983e+00,
                            -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
                            3.10056299e-01, 1.41833842e+00, -1.81386679e-01,
                            3.92868072e-01, -8.14771175e-01, 2.02392387e+00,
                            -9.42091495e-02, -3.77683818e-01, 2.05638766e+00,
                            2.93796062e-01, -6.02131486e-01, 2.70461679e-01,
                            -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
    std::vector<float> S = {0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985,
                            0.13190967, 0.0349741 , 0.18750034, 0.21905553, 0.27000085,
                            0.0547399 , 0.56318235, 0.47422904, 0.78964758, 0.91381913,
                            0.44601166, 0.47902739, 0.13120073, 0.4449684 , 0.18766427,
                            0.15753111, 0.07844277, 0.05120674, 0.36648798, 0.14637007,
                            0.13152322, 0.01560997, 0.29065287, 0.49196178, 0.10550152,
                            0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
                            0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165,
                            0.22390974, 0.11915915, 0.3115395 , 0.35899726, 0.22190949,
                            0.57518375, 0.13888834, 0.7753762 , 0.4642328 , 0.57055861,
                            0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281,
                            0.48770609, 0.06652815, 0.36023033, 0.42343026, 0.24226256,
                            0.17348589, 0.44066274, 0.6865865 , 0.17296699, 0.46923906,
                            0.06921105, 0.3570261 , 0.4125829 , 0.73165393, 0.15302512,
                            0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741,
                            0.36259529, 0.60848355, 0.42618036, 0.31721094, 0.02960522,
                            0.28256637, 0.24389413, 0.2725659 , 0.10663581, 0.27622163,
                            0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392,
                            0.32572666, 0.53076893, 0.11529481, 0.29117745, 0.14625968,
                            0.8756339 , 0.49818122, 0.10656087, 0.1813329 , 0.17664003,
                            0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
                            0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809,
                            0.71028161, 0.29929739, 0.17377149, 0.76075399, 0.20071237,
                            0.32632929, 0.36892858, 0.09416146, 0.26656723, 0.42914796};
    rtg::shape a_shape{rtg::shape::float_type, {5,3,4,2}};
    auto a = p.add_literal(rtg::literal{a_shape, A});
    p.add_instruction(rtg::softmax{}, a);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    // Size the output buffer from the expected data instead of repeating
    // the magic constant 120 in three places.
    std::vector<float> results_vector(S.size());
    memcpy(results_vector.data(), result.data(), results_vector.size()*sizeof(float));
    const float tol = 1e-6f;
    for (std::size_t i = 0; i < results_vector.size(); i++) {
        assert(std::abs(results_vector[i]-S[i]) < tol);
    }
}
int main()
{
    // Run every operator test in turn; each one aborts via assert on
    // failure, so reaching the end of main means all tests passed.
    exp_test();
    gemm_test();
    softmax_test();
    return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment