Unverified Commit 4b76dd0d authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Fix split_reshape for slice len of 1 (#1379)

* fix slice_dim1 for case
parent 7662d9c0
......@@ -985,20 +985,35 @@ struct find_split_reshape
auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
int rsp_axis = -1;
if(ait == rsp_strides.end())
{
return;
}
int rsp_axis = std::distance(rsp_strides.begin(), ait);
else if(ait == rsp_strides.end() - 1)
{
// edge case
// slice_dim == 1, in that case it could match with last stride of 1.
// it should accumulate lengths from last dim in that case. discount 1 to avoid going
// out of bounds.
assert(slc_dim_size == 1);
rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
}
else
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis];
});
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction and add contiguous if needed
......
......@@ -2077,6 +2077,55 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_len_1)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {1, 128, 3}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {2}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), input);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 128};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret});
};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 128, 3}};
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 384};
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {128}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {128}}, {"ends", {256}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {256}}, {"ends", {384}}}), rsp);
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_not_apply)
{
auto create_p = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment