"examples/asr/vscode:/vscode.git/clone" did not exist on "8f68b3f0ba3d7c58661b63c4c4c8656e18aa1ddb"
Unverified commit c89c90db, authored by Shucai Xiao and committed by GitHub

Pooling_nd_cpu_implementation (#548)



* initial progress

* formatting

* add pooling changes

* formatting

* change eliminate_pad

* formatting

* rename var

* formatting

* update op shape test and compute

* formatting

* revert conv constructor

* formatting

* change initializer

* formatting

* fix tidy

* change quant conv and shape check

* add tests and fixes

* formatting

* fix type

* fix conv test

* formatting

* add pooling and bn tests

* formatting

* add inconsistent attr tests

* fix padding issue

* formatting

* fix review comments, remove duplicate test

* formatting

* fix variable

* fix assert bug

* fix attr check

* remove std

* nd pooling cpu implementation

* add unit tests for 1d and 3d pooling operators

* add more unit tests for average pooling

* add pooling unit tests for cpu implementation

* clang format

* fix cppcheck error

* clang format
Co-authored-by: Khalique <15948690+kahmed10@users.noreply.github.com>
parent 39dca6ac
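
For context when reading the tests below: pooling output extents follow the usual rule, out = floor((in + 2*pad - len) / stride) + 1, applied per spatial dimension. A minimal standalone sketch of that rule (illustrative only, not part of this commit; pool_out_dims is a hypothetical helper name):

    #include <cstddef>
    #include <vector>

    // Illustrative helper, not MIGraphX code: output extent per spatial
    // dimension for pooling with symmetric padding. Assumes the window
    // fits, i.e. in + 2*pad >= len, so the unsigned subtraction is safe.
    std::vector<std::size_t> pool_out_dims(const std::vector<std::size_t>& in,
                                           const std::vector<std::size_t>& len,
                                           const std::vector<std::size_t>& pad,
                                           const std::vector<std::size_t>& stride)
    {
        std::vector<std::size_t> out(in.size());
        for(std::size_t d = 0; d < in.size(); ++d)
            out[d] = (in[d] + 2 * pad[d] - len[d]) / stride[d] + 1; // floor division
        return out;
    }

For example, the 1D average-pooling test below with in = 4, len = 2, pad = 1, stride = 2 gives (4 + 2 - 2) / 2 + 1 = 3 outputs per channel.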
@@ -388,40 +388,48 @@ struct cpu_pooling
     {
         argument result{output_shape};
         visit_all(result, args[0])([&](auto output, auto input) {
-            using type = typename decltype(output)::value_type;
-            auto in_h  = input.get_shape().lens()[2];
-            auto in_w  = input.get_shape().lens()[3];
-            par_dfor(output_shape.lens()[0],
-                     output_shape.lens()[1],
-                     output_shape.lens()[2],
-                     output_shape.lens()[3])(
-                [&](std::size_t o, std::size_t w, std::size_t i, std::size_t j) {
-                    const int start_x0  = i * op.stride[0] - op.padding[0];
-                    const int start_y0  = j * op.stride[1] - op.padding[1];
-                    const int hend      = std::min(start_x0 + op.lengths[0], in_h);
-                    const int wend      = std::min(start_y0 + op.lengths[1], in_w);
-                    const int start_x   = std::max(start_x0, 0);
-                    const int start_y   = std::max(start_y0, 0);
-                    const int w_h       = (hend - start_x);
-                    const int w_w       = (wend - start_y);
-                    const int pool_size = std::max(w_h * w_w, 1);
-                    double acc          = Op::start();
-                    dfor(w_h, w_w)([&](int x, int y) {
-                        const int in_x = start_x + x;
-                        const int in_y = start_y + y;
-                        if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
-                        {
-                            acc = Op::apply(acc, input(o, w, in_x, in_y));
-                        }
-                    });
-                    output(o, w, i, j) = type(Op::final(acc, pool_size));
-                });
+            using type   = typename decltype(output)::value_type;
+            auto in_s    = input.get_shape();
+            auto in_lens = in_s.lens();
+            std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());
+            par_for(output_shape.elements(), [&](auto i) {
+                auto idx_o = output_shape.multi(i);
+                auto n_dim = idx_o.size();
+                std::vector<std::size_t> win_start;
+                std::vector<std::size_t> win_size;
+                for(std::size_t dim = 2; dim < n_dim; ++dim)
+                {
+                    auto d_2  = dim - 2;
+                    int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
+                                static_cast<int>(op.padding[d_2]);
+                    int end = std::min(start + op.lengths[d_2], in_lens[dim]);
+                    start   = std::max(start, 0);
+                    win_start.push_back(start);
+                    win_size.push_back(end - start);
+                }
+                shape win_shape{output_shape.type(), win_size};
+                auto pool_size = win_shape.elements();
+                double acc     = Op::start();
+                shape_for_each(win_shape, [&](auto idx_w) {
+                    auto idx = idx_o;
+                    std::transform(idx_w.begin(),
+                                   idx_w.end(),
+                                   win_start.begin(),
+                                   idx.begin() + 2,
+                                   [](auto ii, auto jj) { return ii + jj; });
+                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
+                       idx < in_lens)
+                    {
+                        acc = Op::apply(acc, input[in_s.index(idx)]);
+                    }
+                });
+                output[i] = type(Op::final(acc, pool_size));
+            });
         });
         return result;
     }
 };
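The key change above: instead of hard-coded height/width loops, the window start and clipped size are computed per spatial dimension, and pool_size counts only the in-bounds elements, so "average" divides by the number of valid inputs at the borders. A standalone sketch of that clipping logic, using the parameters of the 1D average-pooling test below (in = 4, len = 2, pad = 1, stride = 2); everything here is illustrative, not MIGraphX API:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int in_len = 4, len = 2, pad = 1, stride = 2;
        for(int o = 0; o < 3; ++o) // (4 + 2*1 - 2) / 2 + 1 = 3 output positions
        {
            int start = o * stride - pad;              // may fall in the left padding
            int end   = std::min(start + len, in_len); // clip the right edge
            start     = std::max(start, 0);            // clip the left edge
            // The averaging divisor is the clipped window size, so border
            // outputs average only the valid elements.
            std::printf("out[%d]: window [%d, %d), pool_size %d\n", o, start, end, end - start);
        }
        return 0;
    }

With the first channel of the "1D case 2" data {1.6321, -2.4186, 0.2239, -1.4232}, this gives 1.6321 (window size 1), (-2.4186 + 0.2239) / 2 ≈ -1.0974, and -1.4232 (window size 1), matching the gold values in the test.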
@@ -364,6 +364,190 @@ TEST_CASE(unsqueeze_test)
     }
 }
 
+TEST_CASE(avgpool_test)
+{
+    // 1D case 1, input is 3D
+    {
+        migraphx::program p;
+        auto s  = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
+        auto op = migraphx::op::pooling{"average"};
+        op.lengths = {2};
+        op.padding = {0};
+        op.stride  = {1};
+        std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
+        auto l0 = p.add_literal(migraphx::literal{s, data});
+        p.add_instruction(op, l0);
+        p.compile(migraphx::cpu::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector;
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+    // 1D case 2, padding 1 and stride 2
+    {
+        migraphx::program p;
+        auto s  = migraphx::shape{migraphx::shape::float_type, {2, 2, 4}};
+        auto op = migraphx::op::pooling{"average"};
+        op.lengths = {2};
+        op.padding = {1};
+        op.stride  = {2};
+        std::vector<float> data{1.6321, -2.4186, 0.2239, -1.4232, 0.8158, 0.4103,
+                                -0.3149, -0.1361, -0.3442, 2.007, 0.4331, 1.5295,
+                                0.9965, 0.4766, 1.0942, -0.2915};
+        auto l0 = p.add_literal(migraphx::literal{s, data});
+        p.add_instruction(op, l0);
+        p.compile(migraphx::cpu::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector;
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{1.6321, -1.0974, -1.4232, 0.8158, 0.0477, -0.1361,
+                                -0.3442, 1.22005, 1.5295, 0.9965, 0.7854, -0.2915};
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+    // 3D, input is 5D
+    {
+        migraphx::program p;
+        auto s  = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
+        auto op = migraphx::op::pooling{"average"};
+        op.lengths = {2, 2, 2};
+        op.padding = {0, 0, 0};
+        op.stride  = {1, 1, 1};
+        std::vector<float> data{
+            -0.179, -1.756, 0.651,  1.955,  1.87,   -0.604, 0.247,  0.449,  -0.137, 1.187,  1.593,
+            0.424,  2.698,  -0.104, -0.069, -1.293, 0.538,  1.291,  0.974,  1.096,  0.74,   -0.669,
+            -1.08,  -1.041, -1.407, 1.43,   -0.211, -0.017, 0.532,  1.276,  0.627,  0.236,  -0.396,
+            -0.204, 0.501,  -0.599, -1.414, -0.615, -0.274, 0.168,  -0.144, 0.5,    1.42,   1.082,
+            -0.952, -0.846, -1.244, 1.475,  1.246,  1.344,  -1.722, -1.24,  -0.851, 0.06,   0.507,
+            0.762,  -0.007, -1.484, 1.028,  0.317,  1.077,  -1.289, 0.875,  -0.417, -0.673, 1.715,
+            -0.307, 0.264,  -0.973, 1.412,  2.561,  -0.515, -0.201, 0.827,  -1.231, 1.958,  -0.552,
+            0.036,  -0.993, -0.859, -1.458, -0.575, 0.048,  -0.779, -1.025, -1.135, 1.166,  -0.131,
+            0.726,  0.52,   0.467,  -0.494, 0.675,  0.203,  -0.63,  -0.918, -0.5,   -1.395, 1.39,
+            1.705,  0.444,  -0.835, -0.506, 0.101,  0.602,  0.543,  0.357,  1.042};
+        auto l0 = p.add_literal(migraphx::literal{s, data});
+        p.add_instruction(op, l0);
+        p.compile(migraphx::cpu::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector;
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{
+            0.908,     0.250625,  0.795,     0.40425,  0.711875, 0.194875,  0.014125,  0.09425,
+            -0.078375, 0.139375,  0.46075,   0.0285,   -0.188125, -0.085,   0.378125,  -0.085375,
+            -0.04,     0.304125,  0.40775,   0.2835,   0.112375, -0.073375, 0.4355,    -0.187,
+            -0.392625, -0.258375, -0.485875, -0.0345,  0.16125,  -0.131875, -0.228375, 0.068625};
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+}
+
+TEST_CASE(maxpool_test_1D_3D)
+{
+    // 1D case 1, input is 3D
+    {
+        migraphx::program p;
+        auto s  = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}};
+        auto op = migraphx::op::pooling{"max"};
+        op.lengths = {2};
+        op.padding = {0};
+        op.stride  = {1};
+        std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
+        auto l0 = p.add_literal(migraphx::literal{s, data});
+        p.add_instruction(op, l0);
+        p.compile(migraphx::cpu::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector;
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+    // 1D case 2, input is 3D
+    {
+        migraphx::program p;
+        auto s  = migraphx::shape{migraphx::shape::float_type, {2, 2, 5}};
+        auto op = migraphx::op::pooling{"max"};
+        op.lengths = {2};
+        op.padding = {0};
+        op.stride  = {2};
+        std::vector<float> data{0.4975, -0.1226, -0.0405, -0.2861, -0.1227, -0.6186, -0.9618,
+                                0.6022, -0.1912, 1.1925,  0.5493,  0.1692,  -0.8039, -1.0281,
+                                0.9907, 0.477,   1.5001,  -1.1603, -1.361,  1.2556};
+        auto l0 = p.add_literal(migraphx::literal{s, data});
+        p.add_instruction(op, l0);
+        p.compile(migraphx::cpu::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector;
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603};
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+    // 3D, input is 5D
+    {
+        migraphx::program p;
+        auto s  = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}};
+        auto op = migraphx::op::pooling{"max"};
+        op.lengths = {2, 2, 2};
+        op.padding = {0, 0, 0};
+        op.stride  = {2, 2, 2};
+        std::vector<float> data{
+            -2.8029, 0.5861,  0.7015,  0.1297,  -1.44,   -1.9472, 0.7812,  2.408,   -0.3145,
+            0.3405,  -0.9146, 0.0624,  1.5064,  -0.8345, 1.7977,  1.8949,  1.0073,  -0.2102,
+            -0.042,  -0.7146, 0.6227,  -0.5263, -2.2598, 0.1713,  0.449,   0.5303,  -0.8622,
+            -0.5691, 0.907,   -0.0569, -1.5348, -0.4109, -0.1461, -0.5445, 0.4266,  0.2282,
+            1.3655,  -2.1519, 0.6068,  -0.2001, -0.4702, 0.3864,  1.7083,  0.9096,  0.4286,
+            -1.8866, 0.7034,  0.0293,  1.4587,  0.7672,  -2.8614, 0.8124,  -0.053,  1.0449,
+            0.845,   -0.0131, 0.1139,  -0.859,  -1.2681, -0.6337, -0.4644, 0.1938,  0.2889,
+            0.9035,  0.7118,  -0.5767, 0.4577,  -0.0549, 0.2237,  0.5756,  0.0677,  -0.0223,
+            -0.329,  0.2364,  2.7666,  -0.7417, -1.3196, -0.2655, 0.1698,  -0.1777, -0.9427,
+            2.6859,  -0.7501, 0.5175,  1.0029,  -2.6436, -0.4388, -1.2348, -0.1539, -0.6229,
+            -0.4136, 0.5085,  0.4136,  -0.6439, -1.1953, -0.406,  -0.0195, 0.1869,  -0.8664,
+            1.1364,  0.5041,  0.0647,  0.1941,  -1.0819, -0.4629, -0.5107, 0.3612,  -0.3583};
+        auto l0 = p.add_literal(migraphx::literal{s, data});
+        p.add_instruction(op, l0);
+        p.compile(migraphx::cpu::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> results_vector;
+        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{1.5064, 1.3655, 0.9035, 2.6859};
+        EXPECT(migraphx::verify_range(results_vector, gold));
+    }
+}
+
 TEST_CASE(globalavgpool_test)
 {
     migraphx::program p;
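
As a sanity check, the gold values of the first avgpool case can be reproduced with a plain sliding-window loop; a standalone sketch (illustrative only, not part of the test suite), with the maxpool case 1 gold following the same pattern with std::max in place of the average:

    #include <cstdio>

    int main()
    {
        // Input of the "1D case 1" avgpool test: shape {1, 3, 4}, lengths {2},
        // padding {0}, stride {1} -> (4 - 2) / 1 + 1 = 3 outputs per channel.
        const float data[3][4] = {{0.3f, 0.2f, 0.4f, 0.1f},
                                  {0.8f, 0.5f, 0.9f, 0.1f},
                                  {0.1f, 0.7f, 0.1f, 0.6f}};
        for(int c = 0; c < 3; ++c)
            for(int o = 0; o < 3; ++o)
                std::printf("%g ", (data[c][o] + data[c][o + 1]) / 2); // window average
        std::printf("\n"); // prints: 0.25 0.3 0.25 0.65 0.7 0.5 0.4 0.4 0.35
        return 0;
    }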