Commit 81b0631e authored by Scott Thornton's avatar Scott Thornton
Browse files

Added softmax2d cpu operator and test

parent dd653b52
...@@ -351,7 +351,7 @@ struct softmax ...@@ -351,7 +351,7 @@ struct softmax
std::string name() const {return "softmax"; } std::string name() const {return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1).only_dims(4);
return inputs.at(0); return inputs.at(0);
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
......
...@@ -11,6 +11,7 @@ namespace rtg { ...@@ -11,6 +11,7 @@ namespace rtg {
template <class T> template <class T>
struct tensor_view struct tensor_view
{ {
using value_type = T;
tensor_view() : m_data(nullptr) {} tensor_view() : m_data(nullptr) {}
tensor_view(shape s, T* d) : m_data(d), m_shape(s) {} tensor_view(shape s, T* d) : m_data(d), m_shape(s) {}
......
...@@ -61,8 +61,8 @@ struct cpu_gemm ...@@ -61,8 +61,8 @@ struct cpu_gemm
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument C{output_shape}; argument result{output_shape};
visit_all(C, args[0], args[1])([&](auto C, auto A, auto B) { visit_all(result, args[0], args[1])([&](auto C, auto A, auto B) {
auto M = A.get_shape().lens()[0]; auto M = A.get_shape().lens()[0];
auto N = B.get_shape().lens()[1]; auto N = B.get_shape().lens()[1];
auto K = B.get_shape().lens()[0]; auto K = B.get_shape().lens()[0];
...@@ -86,7 +86,7 @@ struct cpu_gemm ...@@ -86,7 +86,7 @@ struct cpu_gemm
} }
} }
}); });
return C; return result;
} }
}; };
...@@ -186,21 +186,43 @@ struct cpu_unary ...@@ -186,21 +186,43 @@ struct cpu_unary
} }
}; };
struct softmax struct softmax2d
{ {
std::string name() const { return "cpu::softmax"; } std::string name() const { return "cpu::softmax2d"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { visit_all(result, args[0])([&](auto output, auto input) {
args[0].visit([&](auto input) { using value_type = typename decltype(input)::value_type;
std::transform(input.begin(), input.end(), output.begin(), auto nb = input.get_shape().lens()[0];
[](auto x) { return std::exp(x); }); auto nc = input.get_shape().lens()[1];
float t = std::accumulate(output.begin(), output.end(), zero(input.front())); auto nh = input.get_shape().lens()[2];
std::transform(output.begin(), output.end(), output.begin(), auto nw = input.get_shape().lens()[3];
[t](auto x) { return x/t; }); for (int b = 0; b < nb; b++) {
}); for (int i = 0; i < nh; i++) {
for (int j = 0; j < nw; j++) {
value_type cmax = std::numeric_limits<value_type>::lowest();
for (int c = 0; c < nc; c++) {
cmax = std::max(cmax, input(b, c, i, j));
}
for (int c = 0; c < nc; c++) {
output(b, c, i, j) = std::exp(input(b, c, i, j)-cmax);
}
value_type sum = value_type(0);
for (int c = 0; c < nc; c++) {
sum += output(b, c, i, j);
}
for (int c = 0; c < nc; c++) {
output(b, c, i, j) = output(b, c, i, j)/sum;
}
// for (int c = 0; c < nc; c++) {
// output(b, c, i, j) = input(b, c, i, j);
// }
}
}
}
}); });
return result; return result;
} }
...@@ -333,7 +355,7 @@ struct cpu_apply ...@@ -333,7 +355,7 @@ struct cpu_apply
// Lower a generic softmax instruction to the CPU softmax2d implementation.
void apply_softmax(instruction_ref ins)
{
    // any_cast validates that the instruction really holds a softmax op
    // (presumably throwing otherwise — matches the other apply_* lowerings).
    // softmax carries no parameters, so the cast result is discarded instead
    // of being bound to an unused variable.
    any_cast<softmax>(ins->op);
    prog->replace_instruction(ins, softmax2d{}, ins->arguments);
}
void apply_tanh(instruction_ref ins) void apply_tanh(instruction_ref ins)
......
...@@ -54,8 +54,90 @@ void gemm_test() { ...@@ -54,8 +54,90 @@ void gemm_test() {
} }
} }
// Checks the CPU softmax lowering end to end: builds a program with a single
// softmax over a {5,3,4,2} float literal, compiles it for the CPU target, and
// compares every output element against reference values (S) computed offline.
void softmax_test() {
    rtg::program p;
    std::vector<float> A = {-5.61869681e-01, 9.07827199e-01, 1.29255986e+00,
                            3.18533443e-02, -1.22183852e-03, -2.83830553e-01,
                            -1.03245842e+00, -9.28322077e-01, -8.82696748e-01,
                            1.11327164e-01, -9.20038462e-01, 8.47388089e-01,
                            2.51734018e-01, 1.50563884e+00, 2.23056650e+00,
                            -6.17576987e-02, -1.00264274e-01, -6.10369384e-01,
                            1.17537189e+00, -2.51560897e-01, -8.50333512e-01,
                            -8.03578615e-01, -6.51194930e-01, -2.58137047e-01,
                            4.65528190e-01, 3.23284641e-02, -1.54700470e+00,
                            1.38096774e+00, 5.39869189e-01, -7.56884992e-01,
                            1.81503093e+00, -2.11269641e+00, 1.92466557e+00,
                            1.77230799e+00, 2.21660900e+00, 1.56777036e+00,
                            -2.08995026e-03, 3.50566894e-01, -1.15042710e+00,
                            -1.18577778e+00, 8.90633047e-01, -6.63949102e-02,
                            1.44661188e+00, 1.59215283e+00, -2.56262213e-01,
                            9.39079225e-01, 4.07298543e-02, 3.86590779e-01,
                            6.09607756e-01, 8.22331488e-01, -2.82126725e-01,
                            -9.49052632e-01, -4.24012303e-01, -5.32990396e-01,
                            -3.18386006e+00, 3.27092171e-01, -1.33315325e+00,
                            3.62459183e-01, 3.74710828e-01, -1.30302286e+00,
                            1.79680198e-01, -4.51832324e-01, 4.34282750e-01,
                            -7.09520102e-01, 6.20333970e-01, -1.28712380e+00,
                            2.04130828e-01, -7.70607769e-01, 1.61889160e+00,
                            -1.50951004e+00, -4.10505563e-01, -3.56566496e-02,
                            -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
                            -8.28408226e-02, 2.73412596e-02, 5.79780899e-03,
                            9.87900198e-02, -7.95276761e-01, -1.38536084e+00,
                            -6.63573861e-01, 3.89783204e-01, -1.30670881e+00,
                            -7.62425125e-01, -4.04883057e-01, 6.24344349e-01,
                            3.68128955e-01, -1.01577950e+00, -3.06715906e-01,
                            5.67961395e-01, 2.98198581e-01, -1.63613629e+00,
                            -3.75131965e-01, -6.75393403e-01, 2.59172034e+00,
                            6.75538957e-01, 9.07939598e-02, 1.92257717e-01,
                            -1.21592450e+00, -2.73682117e-01, 1.25232983e+00,
                            -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
                            3.10056299e-01, 1.41833842e+00, -1.81386679e-01,
                            3.92868072e-01, -8.14771175e-01, 2.02392387e+00,
                            -9.42091495e-02, -3.77683818e-01, 2.05638766e+00,
                            2.93796062e-01, -6.02131486e-01, 2.70461679e-01,
                            -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
    // Expected channel-axis softmax of A, precomputed with a reference
    // implementation (same length as A).
    std::vector<float> S = {0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985,
                            0.13190967, 0.0349741,  0.18750034, 0.21905553, 0.27000085,
                            0.0547399,  0.56318235, 0.47422904, 0.78964758, 0.91381913,
                            0.44601166, 0.47902739, 0.13120073, 0.4449684,  0.18766427,
                            0.15753111, 0.07844277, 0.05120674, 0.36648798, 0.14637007,
                            0.13152322, 0.01560997, 0.29065287, 0.49196178, 0.10550152,
                            0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
                            0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165,
                            0.22390974, 0.11915915, 0.3115395,  0.35899726, 0.22190949,
                            0.57518375, 0.13888834, 0.7753762,  0.4642328,  0.57055861,
                            0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281,
                            0.48770609, 0.06652815, 0.36023033, 0.42343026, 0.24226256,
                            0.17348589, 0.44066274, 0.6865865,  0.17296699, 0.46923906,
                            0.06921105, 0.3570261,  0.4125829,  0.73165393, 0.15302512,
                            0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741,
                            0.36259529, 0.60848355, 0.42618036, 0.31721094, 0.02960522,
                            0.28256637, 0.24389413, 0.2725659,  0.10663581, 0.27622163,
                            0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392,
                            0.32572666, 0.53076893, 0.11529481, 0.29117745, 0.14625968,
                            0.8756339,  0.49818122, 0.10656087, 0.1813329,  0.17664003,
                            0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
                            0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809,
                            0.71028161, 0.29929739, 0.17377149, 0.76075399, 0.20071237,
                            0.32632929, 0.36892858, 0.09416146, 0.26656723, 0.42914796};
    rtg::shape a_shape{rtg::shape::float_type, {5,3,4,2}};
    auto a = p.add_literal(rtg::literal{a_shape, A});
    p.add_instruction(rtg::softmax{}, a);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    // Derive the element count from the reference data rather than repeating
    // the magic constant 120, so the sizes cannot drift out of sync.
    std::vector<float> results_vector(S.size());
    memcpy(results_vector.data(), result.data(), S.size()*sizeof(float));
    const float tol = 1e-6f;
    for (std::size_t i = 0; i < results_vector.size(); i++) {
        assert(std::abs(results_vector[i]-S[i]) < tol);
    }
}
int main() int main()
{ {
exp_test(); exp_test();
gemm_test(); gemm_test();
softmax_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment