Unverified commit 93be5e2b, authored by Shucai Xiao and committed by GitHub

Bert fuse slice reshape trans contiguous (#542)



* fix pad calc

* Add decompose pass

* Add decompose test

* Formatting

* bert tf passes correctness

* formatting

* Add remap

* Formatting

* add test

* formatting

* remove comment

* Add compute method for dot

* Formatting

* add inline

* Add finder for horizontal fusion

* Formatting

* Formatting

* Reuse predicate

* formatting

* fix order for literal

* formatting

* add test for gelu

* formatting

* added add_gelu fusion

* Add gemm fusions

* Formatting

* add files

* formatting

* test no mul_add

* formatting

* progress on div

* formatting

* continue work on pass

* remove layernorm opt

* revert reduce file

* Add some fixes for convolution

* Formatting

* Fix shape tests

* Formatting

* Reuse axis equal

* Add initial split fusion

* Formatting

* Update offset

* Workaround outputs that can't accept nonstandard shapes

* Formatting

* Add check for split concat

* Formatting

* Add missing headers

* Formatting

* Add tests

* Formatting

* add optimization for bert

* code backup for bert optimization

* continue testing

* formatting

* fix matcher

* formatting

* add gelu_fn and tests

* formatting

* fix matcher, remove extra tests

* formatting

* fix matcher

* add missing files

* add find_layernorm

* add add_transpose to cmake file

* code backup for the contiguous fusion

* refine ops fusion

* clang format

* fixed bug in previous optimization

* clang format

* add more optimization

* remove unnecessary code

* refinement of the fusion code

* clang format

* fixed a bug

* add used_once

* formatting

* start on new gelu

* formatting

* add matchers in fuse_ops

* formatting

* add dce to fix add_gelu

* add simplify_rsqrt and test

* formatting

* debugging value for matcher

* formatting

* add more to matchers

* formatting

* fix errors

* remove onnx gen

* add any_arg, change matchers to use either_arg

* formatting

* clang format

* formatting

* add used_once

* formatting

* code cleanup

* clang format

* fixed a bug

* remove unnecessary code

* refine comments

* optimize bert to remove more contiguous

* clang format

* remove unnecessary code

* add unit tests for bert optimization

* clang format

* fix review comments

* clang format

* refine a fusion of reshape and slice

* clang format

* fix cppcheck error

* fix review comments

* add the fusion of slice and transpose

* clang format

* add another optimization to fuse slice and transpose

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix review comments
Co-authored-by: Khalique <15948690+kahmed10@users.noreply.github.com>
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: Shucai Xiao <scxiao@prj47-rack-99.local.lan>
parent 7f553e51
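The diff below adds two matchers to simplify_algebra, find_split_reshape and find_split_transpose, and registers them in the pass, so that the per-slice contiguous/reshape and transpose chains produced by BERT-style attention are hoisted above the split. As a rough illustration only — the function names here are made up, only one split branch is shown, and the real matcher additionally requires every slice of the input to apply the same reshape — the rewrite turns a slice -> contiguous -> reshape chain into a single reshape followed by a slice, mirroring the new tests further down:

#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>

// Hypothetical "before" program: one of several split branches,
// each doing slice -> contiguous -> reshape on the shared input.
migraphx::program split_reshape_before()
{
    migraphx::program p;
    auto s     = migraphx::shape{migraphx::shape::float_type, {4, 128, 1920}};
    auto input = p.add_parameter("input", s);
    auto slc   = p.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
    auto cont  = p.add_instruction(migraphx::op::contiguous{}, slc);
    p.add_instruction(migraphx::op::reshape{{4, 128, 10, 64}}, cont);
    return p;
}

// Hypothetical "after" program: a single reshape above the split,
// then a cheap slice per branch (no contiguous copy needed).
migraphx::program split_reshape_after()
{
    migraphx::program p;
    auto s     = migraphx::shape{migraphx::shape::float_type, {4, 128, 1920}};
    auto input = p.add_parameter("input", s);
    auto rsp   = p.add_instruction(migraphx::op::reshape{{4, 128, 30, 64}}, input);
    p.add_instruction(migraphx::op::slice{{2}, {0}, {10}}, rsp);
    return p;
}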
@@ -11,7 +11,9 @@
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/recip.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/algorithm.hpp>
@@ -714,6 +716,130 @@ struct find_rsqrt
}
};
static bool same_ops(const std::vector<instruction_ref>& vec_ins)
{
return std::all_of(vec_ins.begin(), vec_ins.end(), [&](auto i) {
return i->get_operator() == vec_ins.front()->get_operator();
});
}
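// find_split_reshape: matches reshape(contiguous(slice(x))) where x is split
// into several slices (see get_splits). When every split branch applies the
// same reshape and the dimensions in front of the slice axis are untouched,
// the reshape is hoisted above the split and each branch is replaced by a
// slice of the reshaped tensor along that axis.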
struct find_split_reshape
{
auto matcher() const
{
return match::name("reshape")(match::arg(0)(match::name("contiguous")(
match::arg(0)(match::name("slice").bind("slice")))))
.bind("reshape");
}
void apply(program& p, match::matcher_result r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();
auto split_outputs = get_splits(input);
if(split_outputs.empty())
{
return;
}
std::vector<instruction_ref> vec_rsp(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
assert(i->outputs().size() == 1);
auto cont = i->outputs().front();
assert(cont->outputs().size() == 1);
return cont->outputs().front();
});
// all outputs must be reshapes with the same reshape operator
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
if(!same_ops(vec_rsp))
{
return;
}
// ensure the reshape leaves the dimensions in front of the slice axis unchanged
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto in_lens = input->get_shape().lens();
if(!std::equal(in_lens.begin(), in_lens.begin() + axis, dims.begin(), dims.begin() + axis))
{
return;
}
// calculate reshape output shape
auto tmp_lens = vec_rsp.front()->get_shape().lens();
std::vector<int64_t> rsp_lens(tmp_lens.begin(), tmp_lens.end());
int64_t dim_size = rsp_lens[axis];
rsp_lens[axis] *= vec_rsp.size();
// insert the reshape instruction
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_lens}, input);
// replace the original reshape with slice
int64_t i = 0;
for(auto in : vec_rsp)
{
p.replace_instruction(
in, op::slice{{axis}, {i * dim_size}, {(i + 1) * dim_size}}, rsp_ins);
++i;
}
}
};
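// find_split_transpose: matches transpose(slice(x)) where x is split into
// several slices. When every branch uses the same permutation, a single
// transpose is inserted on the split input and each branch becomes a slice
// whose axis is the original slice axis mapped through that permutation.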
struct find_split_transpose
{
auto matcher() const
{
return match::name("transpose")(match::arg(0)(match::name("slice").bind("slice")))
.bind("trans");
}
void apply(program& p, match::matcher_result r) const
{
auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"];
auto input = slc->inputs().front();
auto split_outputs = get_splits(input);
if(split_outputs.empty())
{
return;
}
std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front();
});
// all transposes must use the same permutation
auto perm = any_cast<op::transpose>(trans->get_operator()).dims;
if(!same_ops(vec_trans))
{
return;
}
// insert a transpose instruction right after the split input
auto tr = p.insert_instruction(std::next(input), op::transpose{perm}, input);
// map the slice axis through the permutation
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
auto it = std::find(perm.begin(), perm.end(), axis);
assert(it != perm.end());
auto axis_new = static_cast<int64_t>(std::distance(perm.begin(), it));
for(auto in : split_outputs)
{
auto oper = any_cast<op::slice>(in->get_operator());
auto starts = oper.starts;
auto ends = oper.ends;
auto tr_orig = in->outputs().front();
p.replace_instruction(tr_orig, op::slice{{axis_new}, starts, ends}, tr);
}
}
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
@@ -732,7 +858,9 @@ void simplify_algebra::apply(program& p) const
find_rsqrt{},
find_concat_op{},
find_split_concat{},
- find_splits{});
+ find_splits{},
+ find_split_reshape{},
+ find_split_transpose{});
dead_code_elimination{}.apply(p);
}
}
...
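Analogously to the sketch above, find_split_transpose moves a shared transpose above the split and remaps the slice axis through the permutation (axis 2 becomes axis 1 under {0, 2, 1}). A minimal sketch under the same caveats — hypothetical function names, a single branch shown, and all branches must share the permutation:

#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>

// Hypothetical "before" program: each split branch is slice -> transpose.
migraphx::program split_transpose_before()
{
    migraphx::program p;
    auto s     = migraphx::shape{migraphx::shape::float_type, {4, 128, 1920}};
    auto input = p.add_parameter("input", s);
    auto slc   = p.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
    p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, slc);
    return p;
}

// Hypothetical "after" program: one transpose above the split; the slice keeps
// the same starts/ends but its axis follows the permutation.
migraphx::program split_transpose_after()
{
    migraphx::program p;
    auto s     = migraphx::shape{migraphx::shape::float_type, {4, 128, 1920}};
    auto input = p.add_parameter("input", s);
    auto tr    = p.add_instruction(migraphx::op::transpose{{0, 2, 1}}, input);
    p.add_instruction(migraphx::op::slice{{1}, {0}, {640}}, tr);
    return p;
}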
@@ -737,22 +737,20 @@ struct find_conv_bias_relu
void fuse_ops::apply(program& p) const
{
- // clang-format off
match::find_matches(p, find_gelu{}, find_gelu_new{});
run_passes(p, {dead_code_elimination{}});
match::find_matches(p, find_triadd{});
match::find_matches(p,
find_conv_bias_relu{ctx},
find_conv_bias{ctx},
find_add_gelu{},
find_add_gelu_new{},
find_mul_add{},
find_mul_add_relu{},
find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
- find_add_clip{}
- );
+ find_add_clip{});
// clang-format on
}
...
@@ -806,22 +806,18 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
- auto r = migraphx::op::reshape{{3, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
- auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
- auto cont1 = p2.add_instruction(migraphx::op::contiguous{}, slice1);
- auto reshape1 = p2.add_instruction(r, cont1);
- auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
- auto cont2 = p2.add_instruction(migraphx::op::contiguous{}, slice2);
- auto reshape2 = p2.add_instruction(r, cont2);
- auto add = p2.add_instruction(migraphx::op::add{}, reshape1, reshape2);
+ auto rsp = p2.add_instruction(migraphx::op::reshape{{3, 8}}, relu);
+ auto slc1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {4}}, rsp);
+ auto slc2 = p2.add_instruction(migraphx::op::slice{{1}, {4}, {8}}, rsp);
+ auto add = p2.add_instruction(migraphx::op::add{}, slc1, slc2);
p2.add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
@@ -1387,4 +1383,235 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
EXPECT(p1.sort() == p2.sort());
}
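// The test cases below exercise the new matchers: reorder_reshape_slice and
// reorder_slice_trans check the positive rewrites, while the invalid_axis,
// diff_dims and diff_perm cases check that mismatched reshapes, axes or
// permutations leave the program unchanged.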
TEST_CASE(reorder_reshape_slice)
{
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto create_p1 = [&](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm1}, r2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::dot{}, sum, t2);
p1.add_return({ret});
return p1;
};
auto create_p2 = [&](std::size_t batch_size) {
migraphx::program p2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p2.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
auto r = p2.add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p2.add_instruction(migraphx::op::slice{{2}, {0}, {10}}, r);
auto slc1 = p2.add_instruction(migraphx::op::slice{{2}, {10}, {20}}, r);
auto slc2 = p2.add_instruction(migraphx::op::slice{{2}, {20}, {30}}, r);
auto t0 = p2.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = p2.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = p2.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p2.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p2.add_instruction(migraphx::op::dot{}, sum, t2);
p2.add_return({ret});
return p2;
};
auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size);
run_pass(p1);
auto p2 = create_p2(batch_size);
EXPECT(p1.sort() == p2.sort());
};
test(1);
test(4);
test(8);
}
TEST_CASE(reorder_reshape_slice_invalid_axis)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 129, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 43, 3, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm1}, r2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::dot{}, sum, t2);
p1.add_return({ret});
return p1;
};
auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size);
auto p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
};
test(4);
test(8);
}
TEST_CASE(reorder_reshape_slice_diff_dims)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens1}, c2);
p1.add_return({r0, r1, r2});
return p1;
};
auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size);
auto p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
};
test(4);
test(8);
}
TEST_CASE(reorder_slice_trans)
{
std::vector<int64_t> perm = {0, 2, 1};
auto create_p1 = [&](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm}, slc0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm}, slc1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm}, slc2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::mul{}, sum, t2);
p1.add_return({ret});
return p1;
};
auto create_p2 = [&](std::size_t batch_size) {
migraphx::program p2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p2.add_parameter("input", s);
auto r = p2.add_instruction(migraphx::op::transpose{perm}, input);
auto slc0 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {640}}, r);
auto slc1 = p2.add_instruction(migraphx::op::slice{{1}, {640}, {1280}}, r);
auto slc2 = p2.add_instruction(migraphx::op::slice{{1}, {1280}, {1920}}, r);
auto sum = p2.add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = p2.add_instruction(migraphx::op::mul{}, sum, slc2);
p2.add_return({ret});
return p2;
};
auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size);
run_pass(p1);
auto p2 = create_p2(batch_size);
EXPECT(p1.sort() == p2.sort());
};
test(1);
test(8);
}
TEST_CASE(reorder_slice_trans_diff_perm)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
std::vector<int64_t> perm0 = {0, 2, 1};
std::vector<int64_t> perm1 = {0, 1, 2};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::dot{}, sum, t2);
p1.add_return({ret});
return p1;
};
auto test = [&](std::size_t batch_size) {
auto p1 = create_p1(batch_size);
run_pass(p1);
auto p2 = p1;
EXPECT(p1.sort() == p2.sort());
};
test(1);
test(4);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }