Unverified Commit 8ca7b140 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Bert squad eliminate contiguous (#567)



* code backup

* clang format

* refine the algorithm to support more scenarios

* clang format

* fix review comments

* clang format

* add one more unit test to cover more of the code change
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent a5fb837d
......@@ -760,29 +760,42 @@ struct find_split_reshape
}
// ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto in_lens = input->get_shape().lens();
if(!std::equal(in_lens.begin(), in_lens.begin() + axis, dims.begin(), dims.begin() + axis))
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
// search the reshape output (standard shape) to determine which axis of
// its output corresponds to slc_dim_size
auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
if(ait == rsp_strides.end())
{
return;
}
int rsp_axis = std::distance(rsp_strides.begin(), ait);
// calculate reshape output shape
auto tmp_lens = vec_rsp.front()->get_shape().lens();
std::vector<int64_t> rsp_lens(tmp_lens.begin(), tmp_lens.end());
int64_t dim_size = rsp_lens[axis];
rsp_lens[axis] *= vec_rsp.size();
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis];
});
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_lens}, input);
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_out_lens}, input);
// replace the original reshape with slice
int64_t i = 0;
for(auto in : vec_rsp)
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
p.replace_instruction(
in, op::slice{{axis}, {i * dim_size}, {(i + 1) * dim_size}}, rsp_ins);
++i;
vec_rsp[i], op::slice{{rsp_axis}, {start}, {start + vec_dims[i]}}, rsp_ins);
start += vec_dims[i];
}
}
};
......
......@@ -1449,11 +1449,11 @@ TEST_CASE(reorder_reshape_slice)
test(8);
}
TEST_CASE(reorder_reshape_slice_invalid_axis)
TEST_CASE(reorder_reshape_slice_move_axis1)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 129, 96}};
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p1.add_parameter("input", s);
......@@ -1465,7 +1465,7 @@ TEST_CASE(reorder_reshape_slice_invalid_axis)
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 43, 3, 32};
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
......@@ -1481,9 +1481,31 @@ TEST_CASE(reorder_reshape_slice_invalid_axis)
return p1;
};
auto create_p2 = [](std::size_t batch_size) {
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96};
auto rsp = p.add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p.add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto t0 = p.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto slc1 = p.add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto t1 = p.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto slc2 = p.add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto t2 = p.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p.add_instruction(migraphx::op::dot{}, sum, t2);
p.add_return({ret});
return p;
};
auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size);
auto p2 = p1;
auto p2 = create_p2(batch_size);
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
};
......@@ -1492,6 +1514,87 @@ TEST_CASE(reorder_reshape_slice_invalid_axis)
test(8);
}
TEST_CASE(reorder_reshape_slice_move_axis2)
{
    // Program handed to the pass: a {128, 96} input is cut into three
    // 32-wide pieces along axis 1; each piece is made contiguous and reshaped
    // to {1, 16, 8, 32}. The first two pieces are added and the sum is
    // multiplied by the third.
    auto build_before = [] {
        migraphx::program prog;
        auto in_shape = migraphx::shape{migraphx::shape::float_type, {128, 96}};
        auto data     = prog.add_parameter("input", in_shape);

        // Slice [first, last) of the input along axis 1.
        auto take_cols = [&](int64_t first, int64_t last) {
            return prog.add_instruction(migraphx::op::slice{{1}, {first}, {last}}, data);
        };
        auto piece0 = take_cols(0, 32);
        auto piece1 = take_cols(32, 64);
        auto piece2 = take_cols(64, 96);

        auto packed0 = prog.add_instruction(migraphx::op::contiguous{}, piece0);
        auto packed1 = prog.add_instruction(migraphx::op::contiguous{}, piece1);
        auto packed2 = prog.add_instruction(migraphx::op::contiguous{}, piece2);

        std::vector<int64_t> out_dims = {1, 16, 8, 32};
        auto view0 = prog.add_instruction(migraphx::op::reshape{out_dims}, packed0);
        auto view1 = prog.add_instruction(migraphx::op::reshape{out_dims}, packed1);
        auto view2 = prog.add_instruction(migraphx::op::reshape{out_dims}, packed2);

        auto added   = prog.add_instruction(migraphx::op::add{}, view0, view1);
        auto product = prog.add_instruction(migraphx::op::mul{}, added, view2);
        prog.add_return({product});
        return prog;
    };

    // Program the pass is expected to produce: the input is reshaped once to
    // {1, 16, 8, 96} and the three pieces are then sliced along axis 3 of
    // that reshape, with no contiguous instructions.
    auto build_after = [] {
        migraphx::program prog;
        migraphx::shape in_shape{migraphx::shape::float_type, {128, 96}};
        auto data = prog.add_parameter("input", in_shape);

        std::vector<int64_t> merged_dims = {1, 16, 8, 96};
        auto merged = prog.add_instruction(migraphx::op::reshape{merged_dims}, data);

        auto piece0 = prog.add_instruction(migraphx::op::slice{{3}, {0}, {32}}, merged);
        auto piece1 = prog.add_instruction(migraphx::op::slice{{3}, {32}, {64}}, merged);
        auto piece2 = prog.add_instruction(migraphx::op::slice{{3}, {64}, {96}}, merged);

        auto added   = prog.add_instruction(migraphx::op::add{}, piece0, piece1);
        auto product = prog.add_instruction(migraphx::op::mul{}, added, piece2);
        prog.add_return({product});
        return prog;
    };

    auto p1 = build_before();
    auto p2 = build_after();
    run_pass(p1);
    // After the pass, the rewritten program must match the expected one.
    EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(reorder_reshape_slice_not_apply)
{
    // Same slice -> contiguous -> reshape pattern as the move_axis tests, but
    // each 128x32 piece is reshaped to {1, 16, 16, 16}. The test expects the
    // pass to leave this program untouched.
    // NOTE(review): presumably the rewrite is skipped because no output axis
    // of this reshape lines up with the sliced dimension - confirm against
    // find_split_reshape.
    auto build_program = [] {
        migraphx::program prog;
        auto in_shape = migraphx::shape{migraphx::shape::float_type, {128, 96}};
        auto data     = prog.add_parameter("input", in_shape);

        // Three 32-wide slices along axis 1.
        auto left   = prog.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, data);
        auto middle = prog.add_instruction(migraphx::op::slice{{1}, {32}, {64}}, data);
        auto right  = prog.add_instruction(migraphx::op::slice{{1}, {64}, {96}}, data);

        auto left_c   = prog.add_instruction(migraphx::op::contiguous{}, left);
        auto middle_c = prog.add_instruction(migraphx::op::contiguous{}, middle);
        auto right_c  = prog.add_instruction(migraphx::op::contiguous{}, right);

        std::vector<int64_t> out_dims = {1, 16, 16, 16};
        auto left_r   = prog.add_instruction(migraphx::op::reshape{out_dims}, left_c);
        auto middle_r = prog.add_instruction(migraphx::op::reshape{out_dims}, middle_c);
        auto right_r  = prog.add_instruction(migraphx::op::reshape{out_dims}, right_c);

        auto added   = prog.add_instruction(migraphx::op::add{}, left_r, middle_r);
        auto product = prog.add_instruction(migraphx::op::mul{}, added, right_r);
        prog.add_return({product});
        return prog;
    };

    auto p1 = build_program();
    // Snapshot taken before the pass runs; used as the expected result.
    auto p2 = p1;
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(reorder_reshape_slice_diff_dims)
{
auto create_p1 = [](std::size_t batch_size) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment