Commit 3a0c7c77 authored by Scott Thornton's avatar Scott Thornton
Browse files

Fixed up computing shape for pooling

parent bff0223b
...@@ -183,21 +183,18 @@ struct pooling ...@@ -183,21 +183,18 @@ struct pooling
assert(lengths[0] < (input.lens()[2] + 2 * padding[0])); assert(lengths[0] < (input.lens()[2] + 2 * padding[0]));
assert(lengths[1] < (input.lens()[3] + 2 * padding[1])); assert(lengths[1] < (input.lens()[3] + 2 * padding[1]));
return {t, return {
{ t,
input.lens()[0], {
input.lens()[1], input.lens()[0],
std::size_t(std::max<std::ptrdiff_t>( input.lens()[1],
1, std::size_t(std::max<std::ptrdiff_t>(
std::ceil((input.lens()[2] + 2 * padding[0] - lengths[0]) / 1, (input.lens()[2] + 2 * padding[0] - lengths[0]) / stride[0]) +
static_cast<float>(stride[0])) + 1),
1)), std::size_t(std::max<std::ptrdiff_t>(
std::size_t(std::max<std::ptrdiff_t>( 1, (input.lens()[3] + 2 * padding[1] - lengths[1]) / stride[1]) +
1, 1),
std::ceil((input.lens()[3] + 2 * padding[1] - lengths[1]) / }};
static_cast<float>(stride[1])) +
1)),
}};
} }
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); } argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
...@@ -320,7 +317,7 @@ struct gemm ...@@ -320,7 +317,7 @@ struct gemm
std::string name() const { return "gemm"; } std::string name() const { return "gemm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type(); check_shapes{inputs}.has(2).same_type().same_ndims().only_dims(2);
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
...@@ -431,6 +428,7 @@ struct broadcast ...@@ -431,6 +428,7 @@ struct broadcast
auto input = inputs.at(1); auto input = inputs.at(1);
std::vector<size_t> bcast_strides(result.lens().size(), 0); std::vector<size_t> bcast_strides(result.lens().size(), 0);
if(std::all_of( if(std::all_of(
result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; })) result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
{ {
......
...@@ -252,6 +252,63 @@ void gemm_test() ...@@ -252,6 +252,63 @@ void gemm_test()
} }
} }
void maxpool_test()
{
rtg::program p;
std::vector<float> a = {
-2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806,
-0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688,
0.36873096, 1.18358743, -0.34640595, 1.22098756, 0.01946825, -0.20238149, 0.43348005,
-0.67991608, -0.83041084, 0.93537551, 0.70241445, -0.5654031, -1.30899191, -0.26735824,
-0.52444768, 1.99097753, 1.86504853, -0.26506025, 0.26236168, 0.43763575, 0.95300823,
-1.02733946, -0.74655169, -0.5374338, -0.28901565, -0.59789604, 0.5310151, 0.99125904,
0.40609556, -1.57175648, 0.22031412, 1.45862222, 0.53217483, 1.39087725, 1.00170159,
-0.87175864, -1.7204628, -1.72008383, -0.38656762, -0.01443311, 1.46645272, -1.39995027,
0.22505587, -0.43461126, -0.05511411, -0.79950953, -0.01439556, 0.08795211, 1.18943918,
-0.84079367, -1.73383629, -0.55662078, -0.30626822, -0.67339015, 0.44179603, 0.54316711,
0.40899998, -0.27831686, -1.11900508, -0.0881724, 0.35483059, 2.36277103, -0.04765317,
-0.36865309, 0.73814237, 1.47151589, 1.36546791, -0.32649881, -1.0517807, 2.24768877,
0.68883753, 0.58646208, -0.91017133, -0.50462508, -0.4013325, -0.72348958, -0.47368807,
0.35285577, -1.01817429, -0.5152272, 0.60321307, 0.43521205, -0.23733577, 0.66427642,
0.82949388, 0.82443929, 0.71550399, 0.34561086, 0.68570769, -0.40718508, -1.20350206,
0.15793853, -2.31013632, -0.07934658, -0.09348056, 0.36576006, 2.46601582, 0.11090943,
0.9144392, 0.56759721, -0.22112127, -0.21955389, 0.72474903, -1.28448462, 1.53285873,
0.37437943, 0.31409341, 1.95433736, 0.91620457, 0.86205518, 1.24365854, 0.19248386,
0.22526583, 0.13462132, -0.27561715, -2.06446075, -0.02306402, -1.38278747, 1.1411345,
1.31293464, -1.86041689, 1.06763375, -0.26541466, 1.4545635, 1.11430049, -0.66491818,
0.87101674, 0.67768967, -1.02062869, -1.05031872, -2.2764678, -2.0200038, 0.37592548,
-0.26701379, -0.83388507, 0.19403623, 1.00968623, 0.11020003, 1.16736257, -1.1160326,
0.47346735, 0.6126079, -0.19135755, 1.33624589, -0.29802522, -0.57873946, -1.06555879,
-0.20686582, 1.36892557, -0.19937795, 0.8649236, -1.40126073, 1.53441942, 0.34682792,
-1.31724346, -1.32898355, 2.40126371, 0.07845283, 1.35732043, -0.63678312, 0.39429256,
-1.36487007, -0.31026676, -0.44981545, -0.28994772, -0.14657612, -1.75206447, -0.70612341,
1.20071781, -1.64647579, -0.7133292, 0.88494766, 0.52119428, -2.77387547, 2.07681108,
-0.90133125, 0.2847338, 0.6174528, -0.20616426, -0.64263535, -1.08496261, 0.54275119,
-0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746,
-0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223,
-0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682};
std::vector<float> c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
rtg::shape a_shape{rtg::shape::float_type, {2, 3, 6, 6}};
auto al = p.add_literal(rtg::literal{a_shape, a});
p.add_instruction(rtg::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
p.compile(rtg::cpu::cpu_target{});
auto result = p.eval({});
std::cout << result.get_shape() << std::endl;
std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6;
for(int i = 0; i < results_vector.size(); i++)
{
// std::cout << results_vector[i] << " " << c[i] << std::endl;
EXPECT(std::abs(results_vector[i] - c[i]) < tol);
}
}
void softmax_test() void softmax_test()
{ {
rtg::program p; rtg::program p;
...@@ -564,6 +621,7 @@ int main() ...@@ -564,6 +621,7 @@ int main()
transpose_test(); transpose_test();
contiguous_test(); contiguous_test();
softmax_test(); softmax_test();
maxpool_test();
conv2d_test(); conv2d_test();
conv2d_padding_test(); conv2d_padding_test();
conv2d_padding_stride_test(); conv2d_padding_stride_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment