"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "37fbabf5da68b37f99b0ee8069ce832df75c0c4b"
Unverified Commit 17485202 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Optimize resize and where operators (#784)



* code backup

* clang format

* add a matcher related to the special resize case for optimization

* clang format

* code backup

* clang format

* code backup

* remove unnecessary code

* add optimization for the where op

* clang format

* fix cppcheck error

* add a unit test for optimize resize

* clang format

* remove unnecessary header include

* code backup

* clang format

* add unit tests for optimizing resize

* clang format

* add more unit test for optimizing where op

* clang format

* remove unnecessary code

* add one more optimzation to remove contiguous

* clang format

* add a pointwise requirement

* clang format

* fix cppcheck error

* add one more unit test

* fixed a bug

* clang format

* remove unnecessary code

* clang format

* fix a build error

* fix review comments

* clang format

* fix a review comments

* clang format

* code refinement

* clang format

* refine more code

* refine more code

* fix a bug related to reshape_cont optimization

* clang format

* fix a review comment

* removed an unnecessary comment

* refine code according to comments

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent f7befe50
......@@ -701,6 +701,13 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -17,7 +17,6 @@
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
......@@ -39,13 +38,6 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w")));
}
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
}
auto reduction() { return match::name_contains("reduce"); }
struct find_mul_conv
......@@ -287,7 +279,7 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(pointwise(), match::name("broadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
}
template <class Iterator>
......@@ -407,8 +399,8 @@ struct find_splits
{
auto matcher() const
{
return match::any(match::any_of[match::outputs()](
match::name("slice")(match::any_of[match::outputs()](pointwise(), reduction()))));
return match::any(match::any_of[match::outputs()](match::name("slice")(
match::any_of[match::outputs()](match::pointwise(), reduction()))));
}
static std::vector<std::vector<instruction_ref>>
......
#include <iterator>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
......@@ -318,11 +319,233 @@ struct find_nested_concat
}
};
struct find_resize
{
auto matcher() const
{
return match::name("gather")(
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto ins_rsp = r.instructions["data"];
auto ins_ind = r.instructions["ind"];
// resize input shape
if(ins_rsp->get_shape().lens().size() != 1)
{
return;
}
// resize output shape
const auto& in_shape = ins_rsp->inputs().front()->get_shape();
const auto& out_shape = ins->get_shape();
// check if output shape is multiple of input shape
const auto& in_lens = in_shape.lens();
const auto& out_lens = out_shape.lens();
if(in_lens.size() != out_lens.size())
{
return;
}
// output shape must be multiple of input shape
std::vector<bool> is_multi(in_lens.size());
std::transform(
in_lens.begin(), in_lens.end(), out_lens.begin(), is_multi.begin(), [](auto x, auto y) {
return (y % x == 0);
});
if(not std::all_of(is_multi.begin(), is_multi.end(), [](auto b) { return b; }))
{
return;
}
// output must be multiple of inputs
std::vector<std::size_t> scales(in_lens.size());
std::transform(
in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) {
return y / x;
});
// if ind is not constant, cannot optimize
std::vector<int> vec_ind;
auto arg_ind = ins_ind->eval();
if(arg_ind.empty())
{
return;
}
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
std::vector<int> index(out_shape.elements());
std::iota(index.begin(), index.end(), 0);
if(not std::all_of(index.begin(), index.end(), [&](auto i) {
auto out_idx = out_shape.multi(i);
auto in_idx = out_idx;
std::transform(out_idx.begin(),
out_idx.end(),
scales.begin(),
in_idx.begin(),
[&](auto io, auto scale) { return io - (io % scale); });
return vec_ind[i] == vec_ind[out_shape.index(in_idx)];
}))
{
return;
}
// wrap up shapes for multibroadcast
std::vector<std::pair<std::size_t, std::size_t>> dim_scales;
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
std::back_inserter(dim_scales),
[](auto x, auto y) { return std::make_pair(x, y / x); });
std::vector<int64_t> in_dims;
std::vector<int64_t> out_dims;
for(auto& isp : dim_scales)
{
in_dims.push_back(isp.first);
out_dims.push_back(isp.first * isp.second);
if(isp.first == 1 or isp.second == 1)
{
continue;
}
out_dims.back() = isp.first;
in_dims.push_back(1);
out_dims.push_back(isp.second);
}
auto in_rsp = ins_rsp->inputs().front();
auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
}
};
struct find_where_op
{
auto matcher() const
{
return match::name("gather")(
match::args(match::name("reshape")(match::arg(0)(match::name("concat").bind("data"))),
match::is_constant().bind("ind")));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto concat = r.instructions["data"];
auto ins_ind = r.instructions["ind"];
std::vector<bool> vec_ind;
auto arg_ind = ins_ind->eval();
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
// ind has to be the same value
auto val = vec_ind.front();
if(not std::all_of(vec_ind.begin(), vec_ind.end(), [&](auto v) { return (v == val); }))
{
return;
}
// concat axis must be 0
auto op = any_cast<op::concat>(concat->get_operator());
if(op.axis != 0)
{
return;
}
// check concat inputs, it has to be 2 and have the same shape
const auto& inputs = concat->inputs();
if(inputs.size() != 2)
{
return;
}
if(inputs.at(0)->get_shape() != inputs.at(1)->get_shape())
{
return;
}
if(inputs.at(0)->get_shape().lens() != ins_ind->get_shape().lens())
{
return;
}
if(val)
{
p.replace_instruction(ins, inputs.at(0));
}
else
{
p.replace_instruction(ins, inputs.at(1));
}
}
};
struct find_reshape_cont
{
auto matcher() const
{
return match::pointwise(
match::nargs(2),
match::either_arg(0, 1)(
match::name("reshape")(match::args(match::name("contiguous").bind("cont")))
.bind("rsp"),
match::any()));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto ins_cont = r.instructions["cont"];
auto in_ins = r.instructions["rsp"];
auto cont_input = ins_cont->inputs().front();
auto lens = cont_input->get_shape().lens();
std::vector<int64_t> dims(lens.begin(), lens.end());
if(in_ins->get_shape() != ins->get_shape())
{
return;
}
if(not std::all_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
return i->get_shape().standard();
}))
{
return;
}
auto out_lens = ins->get_shape().lens();
std::vector<int64_t> out_dims(out_lens.begin(), out_lens.end());
std::vector<instruction_ref> inputs;
for(const auto& in : ins->inputs())
{
if(in == in_ins)
{
inputs.push_back(cont_input);
}
else
{
inputs.push_back(
p.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
}
}
auto out = p.insert_instruction(ins, ins->get_operator(), inputs);
p.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
}
};
void simplify_reshapes::apply(module& p) const
{
for(int i = 0; i < 2; i++)
{
match::find_matches(p,
find_where_op{},
find_resize{},
find_reshape_cont{},
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
......
......@@ -517,4 +517,423 @@ TEST_CASE(double_slice_multi_axes)
EXPECT(m1 == m2);
}
TEST_CASE(optimize_resize)
{
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto create_resize_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 2, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3,
3, 3, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = m.add_literal(migraphx::literal(si, ind));
auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), gr);
m.add_return({r});
return m;
};
auto m1 = create_resize_module();
run_pass(m1);
auto create_optimized_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
std::vector<int64_t> dims = {1, 1, 2, 1, 2, 1};
auto rspx = m.add_instruction(migraphx::make_op("reshape", {{"dims", dims}}), inx);
std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3};
auto mbx = m.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", mb_dims}}), rspx);
auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx);
std::vector<int64_t> orig_dims = {1, 2, 4, 6};
auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb);
auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), rmb);
m.add_return({r});
return m;
};
EXPECT(m1 == create_optimized_module());
}
TEST_CASE(optimize_resize_ind_not_apply)
{
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto create_resize_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 2, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 2, 2, 2, 3,
3, 3, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = m.add_literal(migraphx::literal(si, ind));
auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), gr);
m.add_return({r});
return m;
};
auto m1 = create_resize_module();
run_pass(m1);
EXPECT(m1 == create_resize_module());
}
TEST_CASE(optimize_resize_rsp_dim_1)
{
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto create_resize_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 3, 2}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = m.add_literal(migraphx::literal(si, ind));
auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2}}}), inx);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
m.add_return({r});
return m;
};
auto m = create_resize_module();
run_pass(m);
EXPECT(m == create_resize_module());
}
TEST_CASE(optimize_resize_ndims_unequal)
{
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 3, 2}};
auto create_resize_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
auto iny = m.add_parameter("Y", sy);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 3, 2}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = m.add_literal(migraphx::literal(si, ind));
auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx);
auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr);
m.add_return({r});
return m;
};
auto m = create_resize_module();
run_pass(m);
EXPECT(m == create_resize_module());
}
TEST_CASE(optimize_resize_ind_non_brcst)
{
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}};
migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}};
auto create_resize_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
auto iny = m.add_parameter("Y", sy);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = m.add_literal(migraphx::literal(si, ind));
auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx);
auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr);
m.add_return({r});
return m;
};
auto m = create_resize_module();
run_pass(m);
EXPECT(m == create_resize_module());
}
TEST_CASE(optimize_resize_ind_non_const)
{
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}};
migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}};
auto create_resize_module = [&] {
migraphx::module m;
auto inx = m.add_parameter("X", sx);
auto iny = m.add_parameter("Y", sy);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
auto li = m.add_parameter("ind", si);
auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx);
auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr);
m.add_return({r});
return m;
};
auto m = create_resize_module();
run_pass(m);
EXPECT(m == create_resize_module());
}
TEST_CASE(optimize_where_true)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}};
auto create_where_module = [&](bool cond) {
migraphx::module m;
auto inx = m.add_parameter("X", s);
auto iny = m.add_parameter("Y", s);
migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}};
std::vector<char> idata(si.elements(), static_cast<char>(cond));
auto li = m.add_literal(migraphx::literal(si, idata));
auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny);
auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li);
m.add_return({r});
return m;
};
auto create_opt_module = [&](std::string name) {
migraphx::module m;
auto in = m.add_parameter(std::move(name), s);
m.add_return({in});
return m;
};
auto m = create_where_module(true);
run_pass(m);
EXPECT(m == create_opt_module("X"));
auto m1 = create_where_module(false);
run_pass(m1);
EXPECT(m1 == create_opt_module("Y"));
}
TEST_CASE(where_different_cond_values)
{
auto create_where_module = [] {
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}};
auto inx = m.add_parameter("X", s);
auto iny = m.add_parameter("Y", s);
migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}};
std::vector<char> idata = {1, 1, 0, 1, 0, 1};
auto li = m.add_literal(migraphx::literal(si, idata));
auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny);
auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li);
m.add_return({r});
return m;
};
auto m = create_where_module();
run_pass(m);
EXPECT(m == create_where_module());
}
TEST_CASE(where_axis_nonzero)
{
auto create_where_module = [] {
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}};
auto inx = m.add_parameter("X", s);
auto iny = m.add_parameter("Y", s);
migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}};
std::vector<char> idata(6, 1);
auto li = m.add_literal(migraphx::literal(si, idata));
auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny);
auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li);
m.add_return({r});
return m;
};
auto m = create_where_module();
run_pass(m);
EXPECT(m == create_where_module());
}
TEST_CASE(where_three_concat_inputs)
{
auto create_where_module = [] {
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}};
auto inx = m.add_parameter("X", s);
auto iny = m.add_parameter("Y", s);
migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}};
std::vector<char> idata(6, 1);
auto li = m.add_literal(migraphx::literal(si, idata));
auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx);
auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li);
m.add_return({r});
return m;
};
auto m = create_where_module();
run_pass(m);
EXPECT(m == create_where_module());
}
TEST_CASE(where_three_inputs_diff_shapes)
{
auto create_where_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}};
migraphx::shape sy{migraphx::shape::float_type, {2, 1, 3, 2}};
auto inx = m.add_parameter("X", sx);
auto iny = m.add_parameter("Y", sy);
migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}};
std::vector<char> idata(6, 1);
auto li = m.add_literal(migraphx::literal(si, idata));
auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny);
auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li);
m.add_return({r});
return m;
};
auto m = create_where_module();
run_pass(m);
EXPECT(m == create_where_module());
}
TEST_CASE(where_three_lens_diff)
{
auto create_where_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}};
migraphx::shape sy{migraphx::shape::float_type, {1, 1, 3, 2}};
auto inx = m.add_parameter("X", sx);
auto iny = m.add_parameter("Y", sy);
migraphx::shape si{migraphx::shape::bool_type, {1, 1, 6}};
std::vector<char> idata(6, 1);
auto li = m.add_literal(migraphx::literal(si, idata));
auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny);
auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li);
m.add_return({r});
return m;
};
auto m = create_where_module();
run_pass(m);
EXPECT(m == create_where_module());
}
TEST_CASE(reshape_cont)
{
auto create_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}};
auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
auto r = m.add_instruction(migraphx::make_op("add"), rsp, iny);
m.add_return({r});
return m;
};
auto m1 = create_module();
run_pass(m1);
auto create_opt_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}};
auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx);
auto rsp_iny = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), iny);
auto sum = m.add_instruction(migraphx::make_op("add"), mb_inx, rsp_iny);
auto r = m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), sum);
m.add_return({r});
return m;
};
EXPECT(m1 == create_opt_module());
}
TEST_CASE(reshape_input_non_std)
{
auto create_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}};
auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
auto ty = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), iny);
auto r = m.add_instruction(migraphx::make_op("add"), rsp, ty);
m.add_return({r});
return m;
};
auto m1 = create_module();
run_pass(m1);
EXPECT(m1 == create_module());
}
TEST_CASE(reshape_cont_nonpw)
{
auto create_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {1, 4, 1}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2, 2, 6}};
auto inx = m.add_parameter("x", sx);
auto iny = m.add_parameter("y", sy);
auto mb_inx = m.add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", {2, 4, 6}}}), inx);
auto std_inx = m.add_instruction(migraphx::make_op("contiguous"), mb_inx);
auto rsp =
m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 6}}}), std_inx);
auto r = m.add_instruction(migraphx::make_op("convolution"), rsp, iny);
m.add_return({r});
return m;
};
auto m1 = create_module();
run_pass(m1);
EXPECT(m1 == create_module());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment