"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "f4da38a7336e77b39bda793ef4675cdd3e133947"
Unverified Commit 8ca7b140 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Bert squad eliminate contiguous (#567)



* code backup

* clang format

* refine the algorithm to support more scenarios

* clang format

* fix review comments

* clang format

* add one more unit tests to have more code change coverage
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent a5fb837d
...@@ -760,29 +760,42 @@ struct find_split_reshape ...@@ -760,29 +760,42 @@ struct find_split_reshape
} }
// ensure reshape happens after the axis dimension // ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto in_lens = input->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
if(!std::equal(in_lens.begin(), in_lens.begin() + axis, dims.begin(), dims.begin() + axis)) auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
if(ait == rsp_strides.end())
{ {
return; return;
} }
int rsp_axis = std::distance(rsp_strides.begin(), ait);
// calculate reshape output shape // calculate reshape output shape
auto tmp_lens = vec_rsp.front()->get_shape().lens(); std::vector<int64_t> vec_dims(vec_rsp.size());
std::vector<int64_t> rsp_lens(tmp_lens.begin(), tmp_lens.end()); std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
int64_t dim_size = rsp_lens[axis]; return is->get_shape().lens()[rsp_axis];
rsp_lens[axis] *= vec_rsp.size(); });
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction // insert the reshape instruction
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_lens}, input); auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_out_lens}, input);
// replace the original reshape with slice // replace the original reshape with slice
int64_t i = 0; int64_t start = 0;
for(auto in : vec_rsp) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
p.replace_instruction( p.replace_instruction(
in, op::slice{{axis}, {i * dim_size}, {(i + 1) * dim_size}}, rsp_ins); vec_rsp[i], op::slice{{rsp_axis}, {start}, {start + vec_dims[i]}}, rsp_ins);
++i; start += vec_dims[i];
} }
} }
}; };
......
...@@ -1449,11 +1449,11 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1449,11 +1449,11 @@ TEST_CASE(reorder_reshape_slice)
test(8); test(8);
} }
TEST_CASE(reorder_reshape_slice_invalid_axis) TEST_CASE(reorder_reshape_slice_move_axis1)
{ {
auto create_p1 = [](std::size_t batch_size) { auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1; migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 129, 96}}; auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3}; std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1}; std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p1.add_parameter("input", s); auto input = p1.add_parameter("input", s);
...@@ -1465,7 +1465,7 @@ TEST_CASE(reorder_reshape_slice_invalid_axis) ...@@ -1465,7 +1465,7 @@ TEST_CASE(reorder_reshape_slice_invalid_axis)
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1); auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2); auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 43, 3, 32}; std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0); auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1); auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2); auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
...@@ -1481,9 +1481,31 @@ TEST_CASE(reorder_reshape_slice_invalid_axis) ...@@ -1481,9 +1481,31 @@ TEST_CASE(reorder_reshape_slice_invalid_axis)
return p1; return p1;
}; };
auto create_p2 = [](std::size_t batch_size) {
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96};
auto rsp = p.add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p.add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto t0 = p.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto slc1 = p.add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto t1 = p.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto slc2 = p.add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto t2 = p.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p.add_instruction(migraphx::op::dot{}, sum, t2);
p.add_return({ret});
return p;
};
auto test = [&](std::size_t batch_size) { auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size); auto p1 = create_p1(batch_size);
auto p2 = p1; auto p2 = create_p2(batch_size);
run_pass(p1); run_pass(p1);
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
...@@ -1492,6 +1514,87 @@ TEST_CASE(reorder_reshape_slice_invalid_axis) ...@@ -1492,6 +1514,87 @@ TEST_CASE(reorder_reshape_slice_invalid_axis)
test(8); test(8);
} }
TEST_CASE(reorder_reshape_slice_move_axis2)
{
auto create_p1 = [] {
migraphx::program p1;
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {1, 16, 8, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
auto sum = p1.add_instruction(migraphx::op::add{}, r0, r1);
auto ret = p1.add_instruction(migraphx::op::mul{}, sum, r2);
p1.add_return({ret});
return p1;
};
auto create_p2 = [] {
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}};
auto input = p.add_parameter("input", s);
std::vector<int64_t> lens = {1, 16, 8, 96};
auto rsp = p.add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p.add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto slc1 = p.add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto slc2 = p.add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto sum = p.add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = p.add_instruction(migraphx::op::mul{}, sum, slc2);
p.add_return({ret});
return p;
};
auto p1 = create_p1();
auto p2 = create_p2();
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(reorder_reshape_slice_not_apply)
{
auto create_p = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = p.add_parameter("input", s);
auto slc0 = p.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = p.add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = p.add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto c0 = p.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {1, 16, 16, 16};
auto r0 = p.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p.add_instruction(migraphx::op::reshape{lens}, c2);
auto sum = p.add_instruction(migraphx::op::add{}, r0, r1);
auto ret = p.add_instruction(migraphx::op::mul{}, sum, r2);
p.add_return({ret});
return p;
};
auto p1 = create_p();
auto p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(reorder_reshape_slice_diff_dims) TEST_CASE(reorder_reshape_slice_diff_dims)
{ {
auto create_p1 = [](std::size_t batch_size) { auto create_p1 = [](std::size_t batch_size) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment