Unverified Commit 2c8149f6 authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Transpose slice fix (#1499)

This PR resolves the bug addressed in #1496. 
parent 1eb5a1d4
...@@ -1065,11 +1065,23 @@ struct find_split_reshape ...@@ -1065,11 +1065,23 @@ struct find_split_reshape
return; return;
} }
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
if(i->outputs().size() == 1)
{
auto cont = i->outputs().front();
return cont->outputs().size() != 1;
}
return false;
}))
{
return;
}
std::vector<instruction_ref> vec_rsp(split_outputs.size()); std::vector<instruction_ref> vec_rsp(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
assert(i->outputs().size() == 1);
auto cont = i->outputs().front(); auto cont = i->outputs().front();
assert(cont->outputs().size() == 1);
return cont->outputs().front(); return cont->outputs().front();
}); });
......
...@@ -763,16 +763,23 @@ struct find_transpose_slice ...@@ -763,16 +763,23 @@ struct find_transpose_slice
// Compute axis before transpose to use for unsqueeze // Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin(); auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze // Make unsqueeze
std::vector<int64_t> steps(sdistance.size());
std::transform(
slice.axes.begin(),
slice.axes.end(),
sdistance.begin(),
steps.begin(),
[&](const auto ax, const auto sdis) { return ins->get_shape().lens().at(ax) / sdis; });
auto unsqueeze = m.insert_instruction( auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs()); ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", steps}}), ins->inputs());
// Make transpose // Make transpose
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) { std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i > preaxis) if(i >= preaxis)
return i + 1; return i + 1;
return i; return i;
}); });
perm.insert(perm.begin(), preaxis + 1); perm.insert(perm.begin(), preaxis);
auto transpose = auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze); m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and squeeze // Slice and squeeze
......
...@@ -2919,4 +2919,53 @@ TEST_CASE(reorder_slice_ins_deps) ...@@ -2919,4 +2919,53 @@ TEST_CASE(reorder_slice_ins_deps)
EXPECT(m == create_module()); EXPECT(m == create_module());
} }
TEST_CASE(dot_fusion_reshape)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4096, 320}};
auto input = m1.add_parameter("input", s);
auto p0 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 320, 320}}, 0));
auto p1 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 320, 320}}, 1));
auto d0 = m1.add_instruction(migraphx::make_op("dot"), input, p0);
auto d1 = m1.add_instruction(migraphx::make_op("dot"), input, p1);
auto r0 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 8, 40}}}), d0);
m1.add_return({r0, d1});
};
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4096, 320}};
auto input = m2.add_parameter("input", s);
auto p0 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 320, 320}}, 0));
auto p1 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 320, 320}}, 1));
auto c = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), p0, p1);
auto d = m2.add_instruction(migraphx::make_op("dot"), input, c);
auto s0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {320}}}), d);
auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {320}}, {"ends", {640}}}), d);
auto cont0 = m2.add_instruction(migraphx::make_op("contiguous"), s0);
auto r0 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 8, 40}}}), cont0);
m2.add_return({r0, s1});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -1405,9 +1405,9 @@ TEST_CASE(transpose_slice_non_packed_axis) ...@@ -1405,9 +1405,9 @@ TEST_CASE(transpose_slice_non_packed_axis)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x); m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {3}}}), x);
auto transpose = m2.add_instruction( auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze); migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), unsqueeze);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
...@@ -1444,9 +1444,9 @@ TEST_CASE(transpose_slice_non_packed_multi_axis) ...@@ -1444,9 +1444,9 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x); m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {3}}}), x);
auto transpose = m2.add_instruction( auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze); migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1); auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_trans_slice : verify_program<test_trans_slice>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto slice2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
transpose);
auto transpose2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), slice2);
auto slice3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
transpose);
mm->add_return({slice1, transpose2, slice3});
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment