Commit 4a39a0f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
......@@ -10,7 +11,9 @@
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(
m,
{migraphx::normalize_ops{}, migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}});
}
migraphx::instruction_ref
......@@ -66,15 +69,15 @@ TEST_CASE(rewrite_pad)
auto om1 = l1->get_operator().to_value();
auto om2 = l2->get_operator().to_value();
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(om1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(om2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1, 1, 1});
EXPECT(om1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1, 1, 1});
EXPECT(om2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1, 1, 1});
EXPECT(std::none_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_pad_im2col_asymetric)
TEST_CASE(rewrite_pad_im2col_asymmetric)
{
migraphx::module m;
......@@ -95,10 +98,10 @@ TEST_CASE(rewrite_pad_im2col_asymetric)
EXPECT(l0->get_shape() == s0);
auto op0 = l0->get_operator().to_value();
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0});
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0, 2, 2});
run_pass(m);
EXPECT(std::any_of(
EXPECT(std::none_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
......
......@@ -8,4 +8,34 @@ TEST_CASE(generate)
EXPECT(migraphx::generate_literal(s, 1) != migraphx::generate_argument(s, 0));
}
TEST_CASE(fill_tuple)
{
migraphx::shape s0{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::bool_type, {3, 2}};
migraphx::shape s({s0, s1, s2});
auto arg = migraphx::fill_argument(s, 1);
const auto& args = arg.get_sub_objects();
EXPECT(args.at(0) == migraphx::fill_argument(s0, 1));
EXPECT(args.at(1) == migraphx::fill_argument(s1, 1));
EXPECT(args.at(2) == migraphx::fill_argument(s2, 1));
}
TEST_CASE(generate_tuple)
{
migraphx::shape s0{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
migraphx::shape s2{migraphx::shape::bool_type, {3, 2}};
migraphx::shape s({s0, s1, s2});
auto arg = migraphx::generate_argument(s, 1);
const auto& args = arg.get_sub_objects();
EXPECT(args.at(0) == migraphx::generate_argument(s0, 1));
EXPECT(args.at(1) == migraphx::generate_argument(s1, 1));
EXPECT(args.at(2) == migraphx::generate_argument(s2, 1));
EXPECT(args.at(0) != migraphx::generate_argument(s0, 0));
EXPECT(args.at(1) != migraphx::generate_argument(s1, 2));
EXPECT(args.at(2) != migraphx::generate_argument(s2, 0));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -13,15 +13,16 @@
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
void run_lowering(migraphx::program& p)
void run_lowering(migraphx::program& p, bool offload_copy = false)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(*p.get_main_module(),
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
migraphx::gpu::lowering{&ctx, offload_copy},
migraphx::dead_code_elimination{},
migraphx::eliminate_contiguous{"gpu::contiguous"},
migraphx::dead_code_elimination{}});
......@@ -67,4 +68,41 @@ TEST_CASE(tanh_shape)
EXPECT(p1 == p2);
}
TEST_CASE(no_copy_dead_param)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", s);
mm->add_parameter("y", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, x);
mm->add_return({sum});
return p;
};
auto create_gpu_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", s);
mm->add_parameter("y", s);
auto xb = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}}));
auto gx = mm->add_instruction(migraphx::make_op("hip::copy_to_gpu"), x, xb);
auto ab = mm->add_instruction(migraphx::make_op("hip::allocate", {{"shape", to_value(s)}}));
auto sum = mm->add_instruction(migraphx::make_op("gpu::add"), gx, gx, ab);
auto r = mm->add_instruction(migraphx::make_op("hip::copy_from_gpu"), sum);
mm->add_return({r});
return p;
};
auto p1 = create_program();
auto p2 = create_gpu_program();
run_lowering(p1, true);
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -6,8 +6,11 @@
#include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
// NOLINTNEXTLINE
const std::string write_2s = R"__migraphx__(
......@@ -53,7 +56,7 @@ using namespace migraphx;
extern "C" {
__global__ void kernel(void* x, void* y)
{
make_tensors(x, y)([](auto xt, auto yt) __device__ {
make_tensors()(x, y)([](auto xt, auto yt) __device__ {
auto idx = make_index();
const auto stride = idx.nglobal();
for(index_int i = idx.global; i < xt.get_shape().elements(); i += stride)
......@@ -69,24 +72,52 @@ int main() {}
)__migraphx__";
migraphx::src_file make_src_file(const std::string& name, const std::string& content)
// NOLINTNEXTLINE
const std::string check_define = R"__migraphx__(
#ifndef __DEFINE__
#error __DEFINE__ was not defined
#endif
int main() {}
)__migraphx__";
// NOLINTNEXTLINE
const std::string unused_param = R"__migraphx__(
extern "C" {
__global__ void kernel(void* x, void* y)
{}
}
int main() {}
)__migraphx__";
// NOLINTNEXTLINE
const std::string incorrect_program = R"__migraphx__(
extern "C" {
__global__ void kernel(void* x)
{
return {name, std::make_pair(content.data(), content.data() + content.size())};
x += y;
}
}
std::string get_device_name()
int main() {}
)__migraphx__";
migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{
hipDeviceProp_t props{};
int device;
EXPECT(hipGetDevice(&device) == hipSuccess);
EXPECT(hipGetDeviceProperties(&props, device) == hipSuccess);
return "gfx" + std::to_string(props.gcnArch);
return {name, std::make_pair(content.data(), content.data() + content.size())};
}
TEST_CASE(simple_compile_hip)
{
auto binaries = migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", write_2s)}, "", get_device_name());
{make_src_file("main.cpp", write_2s)}, "", migraphx::gpu::get_device_name());
EXPECT(binaries.size() == 1);
migraphx::argument input{{migraphx::shape::int8_type, {5}}};
......@@ -100,10 +131,45 @@ TEST_CASE(simple_compile_hip)
EXPECT(migraphx::all_of(data, [](auto x) { return x == 2; }));
}
auto check_target(const std::string& arch)
{
auto define = "__" + arch + "__";
auto content = migraphx::replace_string(check_define, "__DEFINE__", define);
return migraphx::gpu::compile_hip_src({make_src_file("main.cpp", content)}, "", arch);
}
TEST_CASE(compile_target)
{
EXPECT(not check_target("gfx900").empty());
EXPECT(not check_target("gfx906").empty());
}
TEST_CASE(compile_errors)
{
EXPECT(test::throws([&] {
migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", incorrect_program)}, "", migraphx::gpu::get_device_name());
}));
}
TEST_CASE(compile_warnings)
{
auto compile = [](const std::string& params) {
return migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", unused_param)}, params, migraphx::gpu::get_device_name());
};
EXPECT(not compile("").empty());
EXPECT(not compile("-Wunused-parameter -Wno-error").empty());
EXPECT(not compile("-Wno-unused-parameter -Werror").empty());
EXPECT(test::throws([&] { compile("-Werror=unused-parameter"); }));
EXPECT(test::throws([&] { compile("-Wunused-parameter -Werror"); }));
}
TEST_CASE(code_object_hip)
{
auto binaries = migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", add_2s_binary)}, "", get_device_name());
{make_src_file("main.cpp", add_2s_binary)}, "", migraphx::gpu::get_device_name());
EXPECT(binaries.size() == 1);
migraphx::shape input{migraphx::shape::int8_type, {5}};
......@@ -159,4 +225,26 @@ TEST_CASE(compile_code_object_hip)
EXPECT(result == output_literal.get_argument());
}
TEST_CASE(compile_pointwise)
{
migraphx::shape input{migraphx::shape::float_type, {5, 2}};
migraphx::gpu::context ctx;
auto co = migraphx::gpu::compile_pointwise(ctx, {input, input}, "[](auto x) { return x + 1; }");
migraphx::program p;
auto* mm = p.get_main_module();
auto input_literal = migraphx::generate_literal(input);
auto output_literal = migraphx::transform(input_literal, [](auto x) { return x + 1; });
auto x = mm->add_literal(input_literal);
auto y = mm->add_parameter("output", input);
mm->add_instruction(co, x, y);
p.compile(migraphx::gpu::target{}, migraphx::compile_options{});
auto result =
migraphx::gpu::from_gpu(p.eval({{"output", migraphx::gpu::allocate_gpu(input)}}).front());
EXPECT(result == output_literal.get_argument());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "migraphx/instruction_ref.hpp"
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_passes(migraphx::module& m)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(m,
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
migraphx::dead_code_elimination{},
migraphx::gpu::pack_int8_args{},
migraphx::dead_code_elimination{}});
}
bool get_int8_x4_format()
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto ctx = migraphx::gpu::context{};
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
TEST_CASE(quant_dot)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape);
auto r =
migraphx::add_apply_alpha_beta(m, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape);
auto beta = m.add_literal(1);
auto output = m.add_parameter("test:#output_0", m3_shape);
auto gemm_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto packa = l2;
if(int8_x4)
{
auto alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc);
}
auto gemm =
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
l1,
packa,
gemm_alloc);
auto beta_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
auto beta_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto beta_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
auto m3_beta =
m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc);
auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output);
m.add_return({gemm_add});
return m;
};
auto m1 = create_module();
run_passes(m1);
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_trans)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
auto l1 = m.add_parameter("a", s1);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = migraphx::add_apply_alpha_beta(m, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto alpha = m.add_literal(3);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb);
auto alpha_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha);
auto alpha_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate",
{{"shape",
migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, {3, 2, 5, 8}))}}));
auto alpha_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
// alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
// back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}),
conta,
tl1_convert_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
auto tl1_alpha_int32 = m.add_instruction(
migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc);
// convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
auto tl1_alpha_int8 = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}),
tl1_alpha_int32,
tl1_alpha_int8_alloc);
auto packb = contb;
if(int8_x4)
{
auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
}
auto gemm =
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
tl1_alpha_int8,
packb,
output);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_pad)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {5, 6}};
migraphx::shape s2{migraphx::shape::int8_type, {6, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3);
auto r =
migraphx::add_apply_alpha_beta(m, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {5, 6}};
migraphx::shape ps1{migraphx::shape::int8_type, {5, 8}};
migraphx::shape s2{migraphx::shape::int8_type, {6, 7}};
migraphx::shape ps2{migraphx::shape::int8_type, {8, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3);
auto beta = m.add_literal(1);
auto output = m.add_parameter("test:#output_0", s3);
auto pl1 = l1;
auto packa = l2;
migraphx::instruction_ref pl2{};
if(int8_x4)
{
auto po1 = m.insert_instruction(
l1, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
pl1 = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 2, 0, 0}}, {"value", 0}}),
l1,
po1);
auto po2 = m.insert_instruction(
l2, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
pl2 = m.insert_instruction(
std::next(l2),
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {2, 0, 0, 0}}, {"value", 0}}),
l2,
po2);
}
auto gemm_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
if(int8_x4)
{
auto alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc);
}
auto gemm =
m.add_instruction(migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}),
pl1,
packa,
gemm_alloc);
auto beta_broadcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
auto beta_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto beta_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
auto mul_alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
auto m3_beta =
m.add_instruction(migraphx::make_op("gpu::mul"), l3, beta_contiguous, mul_alloc);
auto gemm_add = m.add_instruction(migraphx::make_op("gpu::add"), gemm, m3_beta, output);
m.add_return({gemm_add});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_trans_pad)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
auto l1 = m.add_parameter("a", s1);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto r = migraphx::add_apply_alpha_beta(m, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape ps1{migraphx::shape::int8_type, {3, 2, 5, 12}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
migraphx::shape ps2{migraphx::shape::int8_type, {3, 2, 12, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto alpha = m.add_literal(3);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta);
auto tl2 =
m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
migraphx::instruction_ref ptb{};
if(int8_x4)
{
ptb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
}
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb);
auto pb = contb;
if(int8_x4)
{
pb = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}),
contb,
ptb);
}
auto alpha_broadcast = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha);
auto alpha_alloc = m.add_instruction(
migraphx::make_op("hip::allocate",
{{"shape",
migraphx::to_value(migraphx::shape(migraphx::shape::int32_type,
conta->get_shape().lens()))}}));
auto alpha_contiguous =
m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
// alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
// back result to int8
auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
auto tl1_convert = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", alpha->get_shape().type()}}),
conta,
tl1_convert_alloc);
auto mul_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
auto tl1_alpha_int32 = m.add_instruction(
migraphx::make_op("gpu::mul"), alpha_contiguous, tl1_convert, mul_alloc);
// convert mul_res to int8
auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
"hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
migraphx::instruction_ref pta{};
if(int8_x4)
{
pta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
}
auto tl1_alpha_int8 = m.add_instruction(
migraphx::make_op("gpu::convert", {{"target_type", conta->get_shape().type()}}),
tl1_alpha_int32,
tl1_alpha_int8_alloc);
auto pa = tl1_alpha_int8;
if(int8_x4)
{
pa = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 0, 3, 0, 0, 0, 0}}}),
tl1_alpha_int8,
pta);
}
auto packb = pb;
if(int8_x4)
{
auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pb, allocpb);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm", {{"int8_x4_format", int8_x4}}), pa, packb, output);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -63,13 +63,12 @@ TEST_CASE(int8_quantization)
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc);
mm->add_instruction(migraphx::op::dot{}, pa, pb, pc);
mm->add_instruction(migraphx::op::dot{}, pa, pb);
return p;
};
......@@ -77,12 +76,11 @@ TEST_CASE(int8_quantization)
{
auto p = create_program();
migraphx::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb);
m["c"] = migraphx::generate_argument(sc);
std::vector<float> ref_result;
migraphx::target ref_t = migraphx::ref::target{};
run_prog(p, ref_t, m, ref_result);
......
......@@ -79,8 +79,7 @@ struct pass_op
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
};
struct mod_pass_op
......
......@@ -6,9 +6,14 @@
#include <functional>
#include <iostream>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <vector>
#ifdef __linux__
#include <unistd.h>
#endif
#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP
#define MIGRAPHX_GUARD_TEST_TEST_HPP
......@@ -79,8 +84,8 @@ struct function
}
};
template <class Iterator>
inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last)
template <class Stream, class Iterator>
inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
{
if(start != last)
{
......@@ -90,22 +95,17 @@ inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last
return s;
}
inline std::ostream& operator<<(std::ostream& s, std::nullptr_t)
template <class Stream>
inline Stream& operator<<(Stream& s, std::nullptr_t)
{
s << "nullptr";
return s;
}
template <class T>
inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
{
s << "{ ";
stream_range(s, v.begin(), v.end());
s << "}";
return s;
}
inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v)
template <class Stream,
class Range,
class = typename std::enable_if<not std::is_convertible<Range, std::string>{}>::type>
inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end()))
{
s << "{ ";
stream_range(s, v.begin(), v.end());
......@@ -264,6 +264,32 @@ struct capture
}
};
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{
......@@ -271,7 +297,7 @@ void failed(T x, const char* msg, const char* func, const char* file, int line,
{
std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl;
std::cout << " FAILED: " << msg << " "
std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " "
<< "[ " << x << " ]" << std::endl;
f();
}
......@@ -315,7 +341,7 @@ auto near(T px, U py, double ptol = 1e-6f)
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class Keyword>
string_map parse(std::vector<std::string> as, Keyword keyword)
string_map generic_parse(std::vector<std::string> as, Keyword keyword)
{
string_map result;
......@@ -331,19 +357,22 @@ string_map parse(std::vector<std::string> as, Keyword keyword)
{
flag = f.front();
result[flag]; // Ensure the flag exists
flag = f.back();
}
}
return result;
}
using test_case = std::function<void()>;
inline auto& get_test_cases()
{
// NOLINTNEXTLINE
static std::vector<std::pair<std::string, std::function<void()>>> cases;
static std::vector<std::pair<std::string, test_case>> cases;
return cases;
}
inline void add_test_case(std::string name, std::function<void()> f)
inline void add_test_case(std::string name, test_case f)
{
get_test_cases().emplace_back(std::move(name), std::move(f));
}
......@@ -357,37 +386,243 @@ struct auto_register_test_case
}
};
inline void run_test_case(const std::string& name, const std::function<void()>& f)
struct failure_error
{
std::cout << "[ RUN ] " << name << std::endl;
f();
std::cout << "[ COMPLETE ] " << name << std::endl;
}
};
inline void run(int argc, const char* argv[])
[[noreturn]] inline void fail() { throw failure_error{}; }
struct driver
{
std::vector<std::string> as(argv + 1, argv + argc);
driver()
{
add_flag({"--help", "-h"}, "Show help");
add_flag({"--list", "-l"}, "List all test cases");
add_flag({"--continue", "-c"}, "Continue after failure");
add_flag({"--quiet", "-q"}, "Don't print out extra output");
}
struct argument
{
std::vector<std::string> flags = {};
std::string help = "";
int nargs = 1;
};
auto args = parse(as, [](auto &&) -> std::vector<std::string> { return {}; });
auto cases = args[""];
if(cases.empty())
void add_arg(const std::vector<std::string>& flags, const std::string& help = "")
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second);
arguments.push_back(argument{flags, help, 1});
}
else
void add_flag(const std::vector<std::string>& flags, const std::string& help = "")
{
std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& name : cases)
arguments.push_back(argument{flags, help, 0});
}
void show_help(const std::string& exe) const
{
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " ";
std::cout << exe << " <test-case>... <options>" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl;
std::cout << " ";
std::cout << color::fg_green << "<test-case>..." << color::reset;
std::cout << std::endl;
std::cout << " "
<< "Test case name to run" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : arguments)
{
auto f = m.find(name);
if(f == m.end())
std::cout << "[ ERROR ] Test case '" << name << "' not found." << std::endl;
std::string prefix = " ";
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset << std::endl;
std::cout << " " << arg.help << std::endl;
}
}
std::ostream& out() const
{
struct null_buffer : std::streambuf
{
virtual int overflow(int c) override { return c; }
};
static null_buffer buffer;
static std::ostream null_stream(&buffer);
if(quiet)
return null_stream;
return std::cout;
}
string_map parse(int argc, const char* argv[]) const
{
std::vector<std::string> args(argv + 1, argv + argc);
string_map keys;
for(auto&& arg : arguments)
{
for(auto&& flag : arg.flags)
{
keys[flag] = {arg.flags.front()};
if(arg.nargs == 0)
keys[flag].push_back("");
}
}
auto result = generic_parse(args, [&](auto&& s) -> std::vector<std::string> {
if(keys.count(s) > 0)
return keys[s];
else
run_test_case(name, f->second);
return {};
});
result["__exe__"].push_back(argv[0]);
return result;
}
static std::string create_command(const string_map& args)
{
std::stringstream ss;
ss << args.at("__exe__").front();
if(args.count("") > 0)
{
for(auto&& arg : args.at(""))
ss << " \"" << arg << "\"";
}
for(auto&& p : args)
{
if(p.first == "__exe__")
continue;
if(p.first.empty())
continue;
ss << " " << p.first;
for(auto&& arg : p.second)
ss << " \"" << arg << "\"";
}
return ss.str();
}
static std::string fork(const std::string& name, string_map args)
{
std::string msg;
args[""] = {name};
args.erase("--continue");
args["--quiet"];
auto cmd = create_command(args);
auto r = std::system(cmd.c_str()); // NOLINT
if(r != 0)
msg = "Exited with " + std::to_string(r);
return msg;
}
void run_test_case(const std::string& name, const test_case& f, const string_map& args)
{
ran++;
out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name
<< color::reset << std::endl;
std::string msg;
if(args.count("--continue") > 0)
{
msg = fork(name, args);
}
else
{
try
{
f();
}
catch(const failure_error&)
{
msg = "Test failure";
}
}
if(msg.empty())
{
out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name
<< color::reset << std::endl;
}
else
{
failed.push_back(name);
out() << color::fg_red << "[ FAILED ] " << color::reset << color::bold << name
<< color::reset << ": " << color::fg_yellow << msg << color::reset << std::endl;
}
}
void run(int argc, const char* argv[])
{
auto args = parse(argc, argv);
if(args.count("--help") > 0)
{
show_help(args.at("__exe__").front());
return;
}
if(args.count("--list") > 0)
{
for(auto&& tc : get_test_cases())
out() << tc.first << std::endl;
return;
}
if(args.count("--quiet") > 0)
quiet = true;
auto cases = args[""];
if(cases.empty())
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second, args);
}
else
{
std::unordered_map<std::string, test_case> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& iname : cases)
{
for(auto&& name : get_case_names(iname))
{
auto f = m.find(name);
if(f == m.end())
{
out() << color::fg_red << "[ ERROR ] Test case '" << name
<< "' not found." << color::reset << std::endl;
failed.push_back(name);
}
else
run_test_case(name, f->second, args);
}
}
}
out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran"
<< color::reset << std::endl;
if(not failed.empty())
{
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size()
<< " tests failed" << color::reset << std::endl;
for(auto&& name : failed)
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name
<< color::reset << std::endl;
std::exit(1);
}
}
std::function<std::vector<std::string>(const std::string&)> get_case_names =
[](const std::string& name) -> std::vector<std::string> { return {name}; };
std::vector<argument> arguments = {};
std::vector<std::string> failed = {};
std::size_t ran = 0;
bool quiet = false;
};
inline void run(int argc, const char* argv[])
{
driver d{};
d.run(argc, argv);
}
} // namespace test
......@@ -404,7 +639,7 @@ inline void run(int argc, const char* argv[])
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
&std::abort)
&test::fail)
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::inline_module{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(cannot_inline_both)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", sd);
std::vector<float> one(sd.elements(), 1);
std::vector<float> two(sd.elements(), 2);
auto* then_smod = p.create_module("then_smod");
auto l1 = then_smod->add_literal(migraphx::literal{sd, one});
auto r1 = then_smod->add_instruction(migraphx::make_op("add"), x, l1);
then_smod->add_return({r1});
auto* else_smod = p.create_module("else_smod");
auto l2 = else_smod->add_literal(migraphx::literal{sd, two});
auto r2 = else_smod->add_instruction(migraphx::make_op("mul"), x, l2);
else_smod->add_return({r2});
migraphx::shape s_cond{migraphx::shape::bool_type, {1}};
auto cond = mm->add_parameter("cond", s_cond);
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_smod, else_smod});
mm->add_return({ret});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program());
}
TEST_CASE(cannot_inline_one)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape s{migraphx::shape::float_type, {5}};
auto cond = mm->add_parameter("cond", cond_s);
auto x = mm->add_parameter("x", s);
auto* then_mod = p.create_module("If_0_if");
std::vector<float> data1 = {1, 2, 3, 4, 5};
auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
then_mod->add_return({l1, x});
auto* else_mod = p.create_module("If_0_else");
std::vector<float> data2 = {5, 4, 3, 2, 1};
auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
auto s2 = else_mod->add_instruction(migraphx::make_op("add"), x, l2);
else_mod->add_return({s2, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program());
}
TEST_CASE(inline_if_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, sm);
then_mod->add_outline(s);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto sm = mm->add_instruction(migraphx::make_op("add"), l1, x);
mm->add_parameter("y", s);
auto r = mm->add_instruction(migraphx::make_op("add"), x, sm);
mm->add_return({r});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(inline_else_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({rt});
auto* else_mod = p.create_module("If_5_else");
else_mod->add_parameter("e", s);
else_mod->add_literal(migraphx::literal(s, ones));
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({re});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_parameter("e", s);
auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
mm->add_return({r});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(if_recursive_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_literal(migraphx::literal(cond_s, {0}));
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto cond1 = mm->add_parameter("cond", cond_s);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(xs, datax));
auto a2 =
else_mod->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1});
auto a3 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2);
else_mod->add_return({l2, a3});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto cond1 = mm->add_parameter("cond", cond_s);
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond1}, {then_mod1, else_mod1});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(if_recursive_cond0_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_literal(migraphx::literal(cond_s, {0}));
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(xs, datax));
auto a2 =
else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
auto a3 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), a2);
else_mod->add_return({l2, a3});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
mm->add_parameter("x1", xs);
mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto m = mm->add_instruction(migraphx::make_op("mul"), y2, ly);
mm->add_return({m});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(inline_tuple_true_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add = mm->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul = mm->add_instruction(migraphx::make_op("mul"), y, m2);
mm->add_return({add, mul});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
TEST_CASE(inline_tuple_false_test)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape sd{migraphx::shape::float_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if");
auto m1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l1);
auto add0 = then_mod->add_instruction(migraphx::make_op("add"), x, m1);
auto m2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l2);
auto mul0 = then_mod->add_instruction(migraphx::make_op("mul"), y, m2);
then_mod->add_return({add0, mul0});
auto* else_mod = p.create_module("If_6_else");
auto me1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul1 = else_mod->add_instruction(migraphx::make_op("mul"), x, me1);
auto me2 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add1 = else_mod->add_instruction(migraphx::make_op("add"), y, me2);
else_mod->add_return({mul1, add1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r0, r1});
return p;
};
auto create_inline = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
migraphx::shape sd{migraphx::shape::float_type, {1}};
mm->add_literal(migraphx::literal(sd, {1}));
mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto m1 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), l3);
auto mul = mm->add_instruction(migraphx::make_op("mul"), x, m1);
auto m2 =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), l3);
auto add = mm->add_instruction(migraphx::make_op("add"), y, m2);
mm->add_return({mul, add});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_inline());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(
m, {migraphx::normalize_ops{}, migraphx::insert_pad{}, migraphx::dead_code_elimination{}});
}
migraphx::instruction_ref
create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::module& m)
{
size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]);
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
return m.add_instruction(
migraphx::make_op("im2col", {{"padding", {0, 0, 1, 1}}}), l_img, l_weights);
}
migraphx::instruction_ref
create_conv(migraphx::instruction_ref& l_img,
size_t channels,
migraphx::module& m,
migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_)
{
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = m.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
op.padding = {0, 0, 1, 1};
return m.add_instruction(op, l_img, l_weights);
}
TEST_CASE(rewrite_pad)
{
migraphx::module m;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input});
auto l0 = create_im2col(l_img, channels, m);
auto l1 = create_conv(l_img, channels, m);
auto l2 = m.add_instruction(
migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {0, 0, 1, 1}}}), l_img);
m.add_instruction(migraphx::make_op("identity"), l0, l1, l2);
run_pass(m);
EXPECT(std::any_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_pad_symmetric)
{
migraphx::module m;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = m.add_literal(migraphx::literal{s_img, input});
m.add_instruction(migraphx::make_op("pooling", {{"mode", "max"}, {"padding", {1, 1, 1, 1}}}),
l_img);
run_pass(m);
EXPECT(std::none_of(
m.begin(), m.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include "test.hpp"
int main() {}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -7,591 +7,500 @@ namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M>
migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
{
migraphx::match::matcher_result result;
for(auto ins : migraphx::iterator_for(modl))
{
result = migraphx::match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
void match1()
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(1);
auto m = match::standard_shape();
auto r = find_match(*mm, m);
migraphx::module mm;
auto l = mm.add_literal(1);
auto m = match::standard_shape();
auto r = find_match(mm, m);
EXPECT(bool{r.result == l});
}
TEST_CASE(match_name1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum");
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_name2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("min");
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_name3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_arg3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(1)(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_arg5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_arg6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg7)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg8)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_args3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_args4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_args5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_args6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::args(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_args7)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::args(match::name("sum")(match::args(
match::name("@literal"), match::name("@literal")))),
match::standard_shape());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_either_args1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("sum"), match::name("@literal")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_either_args2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("@literal"), match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_either_args3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_either_args_any1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m =
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_either_args_any5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
}
TEST_CASE(match_all_of1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_all_of2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_lazy_any_of)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
migraphx::module mm;
auto one = mm.add_literal(1);
mm.add_instruction(pass_op{}, one);
auto m = match::any_of(match::any(), throws());
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_lazy_all_of)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
migraphx::module mm;
auto one = mm.add_literal(1);
mm.add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_lazy_none_of)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
mm->add_instruction(pass_op{}, one);
migraphx::module mm;
auto one = mm.add_literal(1);
mm.add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_any_of1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_any_of_lazy1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("sum"), match::name("sum")).bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
......@@ -600,17 +509,15 @@ TEST_CASE(match_any_of_lazy1)
TEST_CASE(match_any_of_lazy2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
match::args(match::any(), match::any()).bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
......@@ -619,17 +526,15 @@ TEST_CASE(match_any_of_lazy2)
TEST_CASE(match_any_of_lazy3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::any_of(match::args(match::any(), match::any()).bind("x"),
match::args(match::name("@literal"), match::name("@literal")).bind("y")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x"));
EXPECT(bool{r.instructions["x"] == sum});
......@@ -638,17 +543,15 @@ TEST_CASE(match_any_of_lazy3)
TEST_CASE(match_any_of_lazy4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
match::args(match::any().bind("x2"), match::any().bind("y2"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
......@@ -660,17 +563,15 @@ TEST_CASE(match_any_of_lazy4)
TEST_CASE(match_any_of_lazy5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::any_of(
match::args(match::any().bind("x1"), match::any().bind("y1")),
match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
EXPECT(migraphx::contains(r.instructions, "x1"));
EXPECT(migraphx::contains(r.instructions, "y1"));
......@@ -682,194 +583,170 @@ TEST_CASE(match_any_of_lazy5)
TEST_CASE(match_none_of1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(
match::none_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_none_of2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_output1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum = mm.add_instruction(sum_op{}, minus, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum = mm.add_instruction(sum_op{}, minus, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum = mm->add_instruction(sum_op{}, minus, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum = mm.add_instruction(sum_op{}, minus, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto minus_pass = mm->add_instruction(pass_op{}, minus);
auto sum = mm->add_instruction(sum_op{}, minus_pass, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto minus_pass = mm.add_instruction(pass_op{}, minus);
auto sum = mm.add_instruction(sum_op{}, minus_pass, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto minus_pass1 = mm->add_instruction(pass_op{}, minus);
auto minus_pass2 = mm->add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = mm->add_instruction(pass_op{}, minus_pass2);
auto sum = mm->add_instruction(sum_op{}, minus_pass3, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto minus_pass1 = mm.add_instruction(pass_op{}, minus);
auto minus_pass2 = mm.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = mm.add_instruction(pass_op{}, minus_pass2);
auto sum = mm.add_instruction(sum_op{}, minus_pass3, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto pass = mm->add_instruction(pass_op{}, one);
auto sum = mm->add_instruction(sum_op{}, pass, two);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto pass = mm.add_instruction(pass_op{}, one);
auto sum = mm.add_instruction(sum_op{}, pass, two);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto pass = mm->add_instruction(pass_op{}, one);
auto sum1 = mm->add_instruction(sum_op{}, pass, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, one);
auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
mm->add_instruction(pass_op{}, sum3);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto pass = mm.add_instruction(pass_op{}, one);
auto sum1 = mm.add_instruction(sum_op{}, pass, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
mm.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus = mm->add_instruction(minus_op{}, two, one);
auto sum1 = mm->add_instruction(sum_op{}, minus, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, one);
auto sum3 = mm->add_instruction(sum_op{}, sum2, two);
mm->add_instruction(pass_op{}, sum3);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus = mm.add_instruction(minus_op{}, two, one);
auto sum1 = mm.add_instruction(sum_op{}, minus, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, one);
auto sum3 = mm.add_instruction(sum_op{}, sum2, two);
mm.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto minus1 = mm->add_instruction(minus_op{}, two, one);
auto minus2 = mm->add_instruction(minus_op{}, two, minus1);
auto sum = mm->add_instruction(sum_op{}, one, minus2);
mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto minus1 = mm.add_instruction(minus_op{}, two, one);
auto minus2 = mm.add_instruction(minus_op{}, two, minus1);
auto sum = mm.add_instruction(sum_op{}, one, minus2);
mm.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
auto pass = mm->add_instruction(pass_op{}, sum);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
auto pass = mm.add_instruction(pass_op{}, sum);
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
......@@ -877,265 +754,280 @@ TEST_CASE(match_bind1)
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_has_value1)
TEST_CASE(match_bind_modules1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto* child = p.create_module("child");
auto two = child->add_literal(2);
auto sum = child->add_instruction(sum_op{}, one, two);
child->add_instruction(pass_op{}, sum);
mm->add_instruction(mod_pass_op{}, {one}, {child});
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal").bind("one"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(not migraphx::contains(r.instructions, "one"));
EXPECT(not migraphx::contains(r.instructions, "two"));
EXPECT(not migraphx::contains(r.instructions, "sum"));
EXPECT(not migraphx::contains(r.instructions, "pass"));
EXPECT(bool{r.result == child->end()});
}
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
TEST_CASE(match_bind_modules2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto* child = p.create_module("child");
auto two = child->add_literal(2);
auto sum = child->add_instruction(sum_op{}, one, two);
auto pass = child->add_instruction(pass_op{}, sum);
mm->add_instruction(mod_pass_op{}, {one}, {child});
auto m = match::name("pass")(
match::args(match::name("sum")(match::args(match::name("@literal"),
match::name("@literal").bind("two")))
.bind("sum")),
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.result == pass});
}
TEST_CASE(match_has_value1)
{
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(1);
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == one});
}
TEST_CASE(match_has_value2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(2);
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_has_value3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
}
TEST_CASE(match_has_value4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::has_value(3);
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_has_value5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_has_value6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, two);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, two);
mm.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(
match::name("sum"), match::has_value(1), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_tree2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(
match::name("sum"), match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, three, sum1);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(
match::name("sum"), match::has_value(3), match::has_value(1), match::has_value(2));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_tree4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"),
match::has_value(1),
match::has_value(2),
match::has_value(3),
match::has_value(4));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree5)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_tree6)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::tree(match::name("sum"), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
TEST_CASE(match_unordered_tree1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_unordered_tree2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, three, sum1);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, three, sum1);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_unordered_tree3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, two, one);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, two, one);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
}
TEST_CASE(match_unordered_tree4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto three = mm->add_literal(3);
auto sum1 = mm->add_instruction(sum_op{}, one, two);
auto sum2 = mm->add_instruction(sum_op{}, sum1, three);
mm->add_instruction(pass_op{}, sum2);
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto three = mm.add_literal(3);
auto sum1 = mm.add_instruction(sum_op{}, one, two);
auto sum2 = mm.add_instruction(sum_op{}, sum1, three);
mm.add_instruction(pass_op{}, sum2);
auto m = match::unordered_tree(
match::name("sum"), match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == mm->end()});
auto r = find_match(mm, m);
EXPECT(bool{r.result == mm.end()});
}
struct match_find_sum
......@@ -1163,14 +1055,12 @@ struct match_find_literal
TEST_CASE(match_finder)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(pass_op{}, sum);
match::find_matches(*mm, match_find_sum{sum}, match_find_literal{sum});
migraphx::module mm;
auto one = mm.add_literal(1);
auto two = mm.add_literal(2);
auto sum = mm.add_instruction(sum_op{}, one, two);
mm.add_instruction(pass_op{}, sum);
match::find_matches(mm, match_find_sum{sum}, match_find_literal{sum});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -46,6 +46,39 @@ bool no_allocate(const migraphx::module& m)
return std::none_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "allocate"; });
}
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
return std::max(x.first, y.first) < std::min(x.second, y.second);
}
std::pair<std::size_t, std::size_t> get_load_interval(migraphx::instruction_ref a)
{
auto v = a->get_operator().to_value();
auto offset = v.at("offset").to<std::size_t>();
auto s = migraphx::from_value<migraphx::shape>(v.at("shape"));
return {offset, offset + s.bytes()};
}
bool is_overlap_load(migraphx::instruction_ref a, migraphx::instruction_ref b)
{
return is_overlap(get_load_interval(a), get_load_interval(b));
}
bool is_disjoint(const std::vector<migraphx::instruction_ref>& inss)
{
for(auto ins1 : inss)
{
for(auto ins2 : inss)
{
if(ins1 == ins2)
continue;
if(is_overlap_load(ins1, ins2))
return false;
}
}
return true;
}
TEST_CASE(test1)
{
migraphx::module m;
......@@ -57,6 +90,7 @@ TEST_CASE(test1)
run_pass(m);
CHECK(m.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(m));
CHECK(is_disjoint({a1, a2}));
}
TEST_CASE(test2)
......@@ -680,6 +714,3047 @@ TEST_CASE(test39)
CHECK(no_allocate(*else_mod));
}
// NOLINTNEXTLINE
TEST_CASE(rnn_dom)
{
migraphx::module m;
auto mx0 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 10}});
auto mx1 = m.add_instruction(pass_op{});
auto mr = m.add_parameter("r", migraphx::shape{migraphx::shape::float_type, {1, 15, 5}});
auto mx2 = m.add_instruction(pass_op{}, mr);
auto mx3 = m.add_instruction(pass_op{}, mx2);
auto mx4 = m.add_instruction(pass_op{}, mx3);
m.add_instruction(pass_op{});
auto mx6 = m.add_instruction(pass_op{}, mx0, mx1, mx4);
m.add_instruction(pass_op{});
auto mx8 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 15}});
m.add_instruction(pass_op{}, mx8, mx1, mx0, mx6);
auto mseq = m.add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {3, 2, 8}});
auto mx10 = m.add_instruction(pass_op{}, mseq);
auto mx11 = m.add_instruction(pass_op{}, mx10);
auto mw = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 15, 8}});
auto mx12 = m.add_instruction(pass_op{}, mw);
auto mx13 = m.add_instruction(pass_op{}, mx12);
m.add_instruction(pass_op{});
auto mx15 = m.add_instruction(pass_op{}, mx8, mx11, mx13);
m.add_instruction(pass_op{}, mx15, mx1, mx0, mx6);
m.add_instruction(pass_op{});
auto mx18 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{}, mx18, mx6, mx15, mx0, mx1, mx8);
auto mx20 = m.add_instruction(pass_op{}, mx6);
m.add_instruction(pass_op{}, mx20, mx8, mx15, mx18);
auto mx22 = m.add_instruction(pass_op{}, mx15);
m.add_instruction(pass_op{}, mx22, mx1, mx0, mx20, mx6, mx18);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx27 = m.add_instruction(pass_op{}, mx18, mx22, mx20);
m.add_instruction(pass_op{}, mx27, mx15, mx8, mx6, mx20, mx1, mx22, mx0);
m.add_instruction(pass_op{});
auto mx30 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{}, mx30, mx20, mx22, mx1, mx15, mx8, mx6, mx27, mx0, mx18);
auto mx32 = m.add_instruction(pass_op{}, mx15);
m.add_instruction(pass_op{}, mx32, mx20, mx30, mx0, mx18, mx1, mx27, mx6);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx36 = m.add_instruction(pass_op{}, mx30, mx32);
m.add_instruction(pass_op{}, mx36, mx32, mx0, mx27, mx8, mx1, mx15, mx6, mx20, mx22, mx18);
auto mx38 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{}, mx38, mx32, mx0, mx27, mx8, mx1, mx15, mx6, mx20, mx22, mx18);
auto mx40 = m.add_instruction(pass_op{}, mx38, mx36);
m.add_instruction(pass_op{}, mx40, mx32, mx0, mx27, mx8, mx1, mx15, mx6, mx20, mx22, mx18);
m.add_instruction(pass_op{});
auto mx43 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{}, mx43, mx15, mx32, mx27, mx30, mx18, mx8, mx40, mx36, mx22, mx38);
auto mx45 = m.add_instruction(pass_op{}, mx6);
m.add_instruction(pass_op{}, mx45, mx32, mx27, mx30, mx18, mx40, mx36, mx22, mx8, mx15, mx38);
auto mx47 = m.add_instruction(pass_op{}, mx15);
m.add_instruction(
pass_op{}, mx47, mx30, mx18, mx43, mx6, mx1, mx45, mx0, mx27, mx36, mx20, mx40, mx38);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx51 = m.add_instruction(pass_op{}, mx43, mx47, mx45);
m.add_instruction(
pass_op{}, mx51, mx15, mx47, mx32, mx27, mx30, mx18, mx8, mx36, mx22, mx40, mx38);
auto mx53 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(
pass_op{}, mx53, mx15, mx47, mx32, mx27, mx30, mx18, mx8, mx36, mx22, mx40, mx38);
auto mx55 = m.add_instruction(pass_op{}, mx53, mx51, mx1);
m.add_instruction(
pass_op{}, mx55, mx15, mx47, mx32, mx27, mx30, mx18, mx8, mx36, mx22, mx40, mx38);
auto mx57 = m.add_instruction(pass_op{}, mx3);
m.add_instruction(pass_op{});
auto mx59 = m.add_instruction(pass_op{}, mx40, mx55, mx57, mx40);
m.add_instruction(
pass_op{}, mx59, mx15, mx8, mx38, mx18, mx30, mx27, mx47, mx32, mx40, mx36, mx22);
auto mx61 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx61,
mx30,
mx15,
mx1,
mx51,
mx20,
mx59,
mx32,
mx45,
mx22,
mx8,
mx47,
mx40,
mx53,
mx6,
mx55,
mx0,
mx43,
mx38,
mx36);
m.add_instruction(pass_op{});
auto mx64 = m.add_instruction(pass_op{}, mx61, mx27, mx1);
m.add_instruction(pass_op{},
mx64,
mx30,
mx15,
mx1,
mx51,
mx20,
mx59,
mx32,
mx45,
mx22,
mx8,
mx47,
mx40,
mx53,
mx6,
mx55,
mx0,
mx43,
mx38,
mx36);
m.add_instruction(pass_op{});
auto mx67 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx67,
mx18,
mx6,
mx1,
mx51,
mx20,
mx59,
mx27,
mx55,
mx43,
mx38,
mx0,
mx61,
mx45,
mx36,
mx40,
mx53,
mx64,
mx30);
auto mx69 = m.add_instruction(pass_op{});
m.add_instruction(pass_op{},
mx69,
mx18,
mx6,
mx1,
mx51,
mx20,
mx59,
mx27,
mx55,
mx43,
mx38,
mx0,
mx61,
mx45,
mx36,
mx40,
mx53,
mx64,
mx30);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx73 = m.add_instruction(pass_op{}, mx67, mx69, mx27);
m.add_instruction(pass_op{},
mx73,
mx18,
mx6,
mx1,
mx51,
mx20,
mx59,
mx27,
mx55,
mx43,
mx38,
mx0,
mx61,
mx45,
mx36,
mx40,
mx53,
mx64,
mx30);
m.add_instruction(pass_op{});
auto mx76 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx76,
mx64,
mx30,
mx18,
mx40,
mx8,
mx61,
mx38,
mx69,
mx67,
mx73,
mx27,
mx47,
mx32,
mx36,
mx15,
mx22);
m.add_instruction(pass_op{});
auto mx79 = m.add_instruction(pass_op{}, mx76, mx59);
m.add_instruction(pass_op{},
mx79,
mx64,
mx30,
mx18,
mx40,
mx8,
mx61,
mx38,
mx69,
mx67,
mx73,
mx27,
mx47,
mx32,
mx36,
mx15,
mx22);
auto mx81 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx81,
mx36,
mx32,
mx27,
mx47,
mx18,
mx30,
mx73,
mx67,
mx22,
mx15,
mx61,
mx8,
mx64,
mx40,
mx69,
mx38);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx85 = m.add_instruction(pass_op{}, mx81, mx73, mx79, mx64);
m.add_instruction(pass_op{},
mx85,
mx36,
mx32,
mx27,
mx47,
mx18,
mx30,
mx73,
mx67,
mx22,
mx15,
mx61,
mx8,
mx64,
mx40,
mx69,
mx38);
m.add_instruction(pass_op{});
auto mx88 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 10}});
m.add_instruction(pass_op{},
mx88,
mx36,
mx32,
mx27,
mx47,
mx18,
mx30,
mx73,
mx67,
mx22,
mx15,
mx61,
mx8,
mx64,
mx40,
mx69,
mx38);
auto mx90 = m.add_instruction(pass_op{}, mx88, mx85, mx4);
m.add_instruction(pass_op{},
mx90,
mx36,
mx32,
mx27,
mx47,
mx18,
mx30,
mx73,
mx67,
mx22,
mx15,
mx61,
mx8,
mx64,
mx40,
mx69,
mx38);
m.add_instruction(pass_op{});
auto mx93 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 15}});
m.add_instruction(pass_op{},
mx93,
mx51,
mx88,
mx20,
mx64,
mx43,
mx61,
mx53,
mx81,
mx47,
mx6,
mx45,
mx0,
mx55,
mx18,
mx76,
mx1,
mx79,
mx85,
mx90,
mx8,
mx69,
mx67,
mx73,
mx32,
mx59,
mx22,
mx15,
mx27);
auto mx95 = m.add_instruction(pass_op{}, mseq);
auto mx96 = m.add_instruction(pass_op{}, mx95);
m.add_instruction(pass_op{});
auto mx98 = m.add_instruction(pass_op{}, mx93, mx96, mx13);
m.add_instruction(pass_op{},
mx98,
mx51,
mx88,
mx20,
mx64,
mx43,
mx61,
mx53,
mx81,
mx47,
mx6,
mx45,
mx0,
mx55,
mx18,
mx76,
mx1,
mx79,
mx85,
mx90,
mx8,
mx69,
mx67,
mx73,
mx32,
mx59,
mx22,
mx15,
mx27);
m.add_instruction(pass_op{});
auto mx101 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx101,
mx43,
mx40,
mx53,
mx59,
mx51,
mx6,
mx61,
mx81,
mx38,
mx45,
mx20,
mx0,
mx76,
mx55,
mx18,
mx85,
mx1,
mx93,
mx79,
mx90,
mx27,
mx88,
mx64,
mx30,
mx98,
mx36);
auto mx103 = m.add_instruction(pass_op{}, mx90);
m.add_instruction(pass_op{},
mx103,
mx64,
mx101,
mx15,
mx67,
mx73,
mx18,
mx40,
mx8,
mx47,
mx98,
mx27,
mx32,
mx61,
mx22,
mx93,
mx69,
mx36,
mx38,
mx30);
auto mx105 = m.add_instruction(pass_op{}, mx98);
m.add_instruction(pass_op{},
mx105,
mx43,
mx88,
mx53,
mx64,
mx59,
mx6,
mx76,
mx61,
mx81,
mx47,
mx103,
mx22,
mx45,
mx0,
mx55,
mx18,
mx85,
mx51,
mx20,
mx1,
mx79,
mx90,
mx8,
mx101,
mx15,
mx69,
mx67,
mx73,
mx32,
mx27);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx110 = m.add_instruction(pass_op{}, mx101, mx105, mx103);
m.add_instruction(pass_op{},
mx110,
mx88,
mx40,
mx93,
mx59,
mx43,
mx61,
mx53,
mx81,
mx103,
mx6,
mx45,
mx0,
mx55,
mx18,
mx64,
mx20,
mx76,
mx1,
mx79,
mx38,
mx85,
mx90,
mx27,
mx30,
mx105,
mx98,
mx51,
mx36);
m.add_instruction(pass_op{});
auto mx113 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx113,
mx59,
mx20,
mx51,
mx1,
mx79,
mx90,
mx55,
mx85,
mx76,
mx81,
mx47,
mx6,
mx38,
mx88,
mx43,
mx40,
mx0,
mx45,
mx53,
mx93,
mx8,
mx101,
mx15,
mx69,
mx67,
mx73,
mx32,
mx110,
mx22,
mx103,
mx30,
mx36,
mx98,
mx105);
auto mx115 = m.add_instruction(pass_op{}, mx98);
m.add_instruction(pass_op{},
mx115,
mx59,
mx20,
mx51,
mx1,
mx79,
mx90,
mx55,
mx18,
mx85,
mx76,
mx61,
mx81,
mx47,
mx6,
mx88,
mx43,
mx0,
mx45,
mx53,
mx64,
mx8,
mx101,
mx15,
mx69,
mx67,
mx73,
mx113,
mx32,
mx110,
mx22,
mx103,
mx27);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx119 = m.add_instruction(pass_op{}, mx113, mx115);
m.add_instruction(pass_op{},
mx119,
mx59,
mx20,
mx51,
mx1,
mx79,
mx90,
mx55,
mx85,
mx76,
mx81,
mx47,
mx6,
mx38,
mx88,
mx43,
mx40,
mx0,
mx45,
mx53,
mx93,
mx8,
mx101,
mx15,
mx69,
mx67,
mx73,
mx32,
mx110,
mx22,
mx103,
mx30,
mx36,
mx115,
mx98,
mx105);
auto mx121 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx121,
mx59,
mx20,
mx51,
mx1,
mx79,
mx90,
mx55,
mx85,
mx76,
mx81,
mx47,
mx6,
mx38,
mx88,
mx43,
mx40,
mx0,
mx45,
mx53,
mx93,
mx8,
mx101,
mx15,
mx69,
mx67,
mx73,
mx32,
mx110,
mx22,
mx103,
mx30,
mx36,
mx115,
mx98,
mx105);
auto mx123 = m.add_instruction(pass_op{}, mx121, mx119);
m.add_instruction(pass_op{},
mx123,
mx59,
mx20,
mx51,
mx1,
mx79,
mx90,
mx55,
mx85,
mx76,
mx81,
mx47,
mx6,
mx38,
mx88,
mx43,
mx40,
mx0,
mx45,
mx53,
mx93,
mx8,
mx101,
mx15,
mx69,
mx67,
mx73,
mx32,
mx110,
mx22,
mx103,
mx30,
mx36,
mx115,
mx98,
mx105);
m.add_instruction(pass_op{});
auto mx126 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx126,
mx115,
mx113,
mx8,
mx67,
mx61,
mx73,
mx18,
mx123,
mx119,
mx32,
mx15,
mx36,
mx110,
mx27,
mx101,
mx22,
mx98,
mx47,
mx40,
mx93,
mx38,
mx69,
mx121,
mx64,
mx30,
mx105);
auto mx128 = m.add_instruction(pass_op{}, mx90);
m.add_instruction(pass_op{},
mx128,
mx93,
mx98,
mx8,
mx67,
mx73,
mx18,
mx123,
mx61,
mx40,
mx47,
mx27,
mx32,
mx101,
mx22,
mx15,
mx110,
mx36,
mx119,
mx38,
mx64,
mx30,
mx69,
mx121,
mx113,
mx115,
mx105);
auto mx130 = m.add_instruction(pass_op{}, mx98);
m.add_instruction(pass_op{},
mx130,
mx119,
mx64,
mx22,
mx110,
mx126,
mx128,
mx121,
mx113,
mx67,
mx90,
mx69,
mx15,
mx20,
mx8,
mx27,
mx51,
mx85,
mx79,
mx123,
mx103,
mx18,
mx55,
mx32,
mx0,
mx45,
mx61,
mx53,
mx76,
mx6,
mx47,
mx59,
mx73,
mx81,
mx88,
mx1,
mx43,
mx101);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx134 = m.add_instruction(pass_op{}, mx126, mx130, mx128);
m.add_instruction(pass_op{},
mx134,
mx130,
mx8,
mx67,
mx61,
mx73,
mx18,
mx123,
mx119,
mx32,
mx15,
mx36,
mx110,
mx27,
mx101,
mx22,
mx113,
mx115,
mx98,
mx47,
mx40,
mx93,
mx38,
mx69,
mx121,
mx64,
mx30,
mx105);
auto mx136 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx136,
mx130,
mx8,
mx67,
mx61,
mx73,
mx18,
mx123,
mx119,
mx32,
mx15,
mx36,
mx110,
mx27,
mx101,
mx22,
mx113,
mx115,
mx98,
mx47,
mx40,
mx93,
mx38,
mx69,
mx121,
mx64,
mx30,
mx105);
auto mx138 = m.add_instruction(pass_op{}, mx136, mx134, mx85);
m.add_instruction(pass_op{},
mx138,
mx130,
mx8,
mx67,
mx61,
mx73,
mx18,
mx123,
mx119,
mx32,
mx15,
mx36,
mx110,
mx27,
mx101,
mx22,
mx113,
mx115,
mx98,
mx47,
mx40,
mx93,
mx38,
mx69,
mx121,
mx64,
mx30,
mx105);
m.add_instruction(pass_op{});
auto mx141 = m.add_instruction(pass_op{}, mx123, mx138, mx57, mx123);
m.add_instruction(pass_op{},
mx141,
mx113,
mx115,
mx130,
mx105,
mx38,
mx93,
mx61,
mx98,
mx27,
mx64,
mx30,
mx119,
mx121,
mx69,
mx8,
mx67,
mx40,
mx47,
mx32,
mx101,
mx22,
mx36,
mx110,
mx15,
mx73,
mx18,
mx123);
auto mx143 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx143,
mx8,
mx73,
mx121,
mx67,
mx101,
mx110,
mx69,
mx15,
mx138,
mx88,
mx43,
mx79,
mx53,
mx61,
mx45,
mx18,
mx0,
mx6,
mx27,
mx22,
mx134,
mx32,
mx1,
mx119,
mx59,
mx85,
mx103,
mx126,
mx64,
mx128,
mx55,
mx76,
mx47,
mx81,
mx90,
mx136,
mx51,
mx141,
mx20,
mx113,
mx123);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx147 = m.add_instruction(pass_op{}, mx143, mx69, mx110);
m.add_instruction(pass_op{},
mx147,
mx8,
mx73,
mx121,
mx67,
mx101,
mx110,
mx69,
mx15,
mx138,
mx88,
mx43,
mx79,
mx53,
mx61,
mx45,
mx18,
mx0,
mx6,
mx27,
mx22,
mx134,
mx32,
mx1,
mx119,
mx59,
mx85,
mx103,
mx126,
mx64,
mx128,
mx55,
mx76,
mx47,
mx81,
mx90,
mx136,
mx51,
mx141,
mx20,
mx113,
mx123);
m.add_instruction(pass_op{});
auto mx150 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx150,
mx30,
mx121,
mx115,
mx98,
mx130,
mx85,
mx88,
mx90,
mx79,
mx1,
mx93,
mx64,
mx18,
mx53,
mx61,
mx38,
mx27,
mx147,
mx0,
mx6,
mx51,
mx40,
mx134,
mx43,
mx119,
mx59,
mx45,
mx76,
mx128,
mx81,
mx136,
mx55,
mx138,
mx123,
mx126,
mx141,
mx103,
mx20,
mx105,
mx113,
mx143,
mx36);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx154 = m.add_instruction(pass_op{}, mx150, mx110, mx85);
m.add_instruction(pass_op{},
mx154,
mx30,
mx121,
mx115,
mx98,
mx130,
mx85,
mx88,
mx90,
mx79,
mx1,
mx93,
mx64,
mx18,
mx53,
mx61,
mx38,
mx27,
mx147,
mx0,
mx6,
mx51,
mx40,
mx134,
mx43,
mx119,
mx59,
mx45,
mx76,
mx128,
mx81,
mx136,
mx55,
mx138,
mx123,
mx126,
mx141,
mx103,
mx20,
mx105,
mx113,
mx143,
mx36);
m.add_instruction(pass_op{});
auto mx157 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx157,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx61,
mx98,
mx40,
mx27,
mx121,
mx30,
mx154,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx15,
mx18,
mx123,
mx143);
m.add_instruction(pass_op{});
auto mx160 = m.add_instruction(pass_op{}, mx157, mx141);
m.add_instruction(pass_op{},
mx160,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx61,
mx98,
mx40,
mx27,
mx121,
mx30,
mx154,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx15,
mx18,
mx123,
mx143);
auto mx162 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx162,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx61,
mx98,
mx40,
mx27,
mx121,
mx30,
mx154,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx15,
mx18,
mx123,
mx143);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx166 = m.add_instruction(pass_op{}, mx162, mx147, mx160, mx154);
m.add_instruction(pass_op{},
mx166,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx61,
mx98,
mx40,
mx27,
mx121,
mx30,
mx154,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx15,
mx18,
mx123,
mx143);
m.add_instruction(pass_op{});
auto mx169 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 15}});
m.add_instruction(pass_op{},
mx169,
mx154,
mx90,
mx88,
mx79,
mx126,
mx15,
mx103,
mx22,
mx134,
mx166,
mx30,
mx73,
mx20,
mx128,
mx160,
mx8,
mx45,
mx0,
mx6,
mx157,
mx53,
mx136,
mx93,
mx47,
mx81,
mx141,
mx85,
mx110,
mx59,
mx1,
mx162,
mx101,
mx36,
mx38,
mx76,
mx143,
mx67,
mx147,
mx150,
mx138,
mx115,
mx105,
mx51,
mx69,
mx40,
mx32,
mx43,
mx55,
mx130,
mx98);
auto mx171 = m.add_instruction(pass_op{}, mseq);
auto mx172 = m.add_instruction(pass_op{}, mx171);
m.add_instruction(pass_op{});
auto mx174 = m.add_instruction(pass_op{}, mx169, mx172, mx13);
m.add_instruction(pass_op{},
mx174,
mx154,
mx90,
mx88,
mx79,
mx126,
mx15,
mx103,
mx22,
mx134,
mx166,
mx30,
mx73,
mx20,
mx128,
mx160,
mx8,
mx45,
mx0,
mx6,
mx157,
mx53,
mx136,
mx93,
mx47,
mx81,
mx141,
mx85,
mx110,
mx59,
mx1,
mx162,
mx101,
mx36,
mx38,
mx76,
mx143,
mx67,
mx147,
mx150,
mx138,
mx115,
mx105,
mx51,
mx69,
mx40,
mx32,
mx43,
mx55,
mx130,
mx98);
m.add_instruction(pass_op{});
auto mx177 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 10}});
m.add_instruction(pass_op{},
mx177,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx154,
mx61,
mx98,
mx40,
mx27,
mx174,
mx121,
mx30,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx169,
mx15,
mx18,
mx123,
mx143);
m.add_instruction(pass_op{});
auto mx180 = m.add_instruction(pass_op{}, mx177, mx166, mx4);
m.add_instruction(pass_op{},
mx180,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx154,
mx61,
mx98,
mx40,
mx27,
mx174,
mx121,
mx30,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx169,
mx15,
mx18,
mx123,
mx143);
m.add_instruction(pass_op{});
auto mx183 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx183,
mx67,
mx90,
mx150,
mx138,
mx88,
mx79,
mx126,
mx15,
mx103,
mx22,
mx134,
mx180,
mx166,
mx174,
mx73,
mx20,
mx154,
mx32,
mx43,
mx55,
mx157,
mx18,
mx0,
mx113,
mx6,
mx76,
mx53,
mx61,
mx177,
mx136,
mx81,
mx141,
mx85,
mx110,
mx64,
mx45,
mx8,
mx169,
mx59,
mx1,
mx162,
mx101,
mx119,
mx51,
mx69,
mx128,
mx160,
mx27,
mx47,
mx123,
mx121);
auto mx185 = m.add_instruction(pass_op{}, mx180);
m.add_instruction(pass_op{},
mx185,
mx101,
mx8,
mx115,
mx130,
mx105,
mx38,
mx147,
mx93,
mx64,
mx154,
mx61,
mx98,
mx40,
mx27,
mx183,
mx174,
mx121,
mx30,
mx113,
mx73,
mx119,
mx36,
mx150,
mx69,
mx67,
mx47,
mx110,
mx32,
mx22,
mx169,
mx15,
mx18,
mx123,
mx143);
auto mx187 = m.add_instruction(pass_op{}, mx174);
m.add_instruction(pass_op{},
mx187,
mx150,
mx128,
mx67,
mx15,
mx88,
mx43,
mx79,
mx126,
mx103,
mx22,
mx90,
mx180,
mx183,
mx166,
mx141,
mx30,
mx20,
mx59,
mx55,
mx38,
mx160,
mx0,
mx32,
mx85,
mx6,
mx76,
mx157,
mx45,
mx162,
mx138,
mx154,
mx53,
mx177,
mx136,
mx51,
mx47,
mx81,
mx93,
mx73,
mx8,
mx110,
mx101,
mx69,
mx185,
mx36,
mx143,
mx147,
mx134,
mx1,
mx130,
mx115,
mx105,
mx40,
mx98);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx192 = m.add_instruction(pass_op{}, mx183, mx187, mx185);
m.add_instruction(pass_op{},
mx192,
mx150,
mx128,
mx67,
mx187,
mx15,
mx88,
mx43,
mx79,
mx126,
mx103,
mx64,
mx22,
mx90,
mx180,
mx141,
mx20,
mx59,
mx134,
mx1,
mx55,
mx113,
mx160,
mx0,
mx32,
mx85,
mx6,
mx76,
mx157,
mx45,
mx162,
mx138,
mx154,
mx53,
mx61,
mx177,
mx174,
mx136,
mx119,
mx185,
mx51,
mx47,
mx81,
mx73,
mx8,
mx110,
mx18,
mx169,
mx101,
mx69,
mx27,
mx123,
mx166,
mx121);
m.add_instruction(pass_op{});
auto mx195 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx195,
mx115,
mx105,
mx98,
mx123,
mx27,
mx126,
mx103,
mx64,
mx183,
mx174,
mx136,
mx177,
mx141,
mx51,
mx93,
mx113,
mx38,
mx160,
mx55,
mx30,
mx61,
mx138,
mx53,
mx76,
mx85,
mx6,
mx20,
mx59,
mx0,
mx40,
mx43,
mx88,
mx79,
mx180,
mx90,
mx187,
mx81,
mx128,
mx157,
mx45,
mx162,
mx134,
mx1,
mx130,
mx147,
mx166,
mx121,
mx18,
mx169,
mx143,
mx119,
mx36,
mx185,
mx192);
auto mx197 = m.add_instruction(pass_op{}, mx174);
m.add_instruction(pass_op{},
mx197,
mx128,
mx150,
mx101,
mx69,
mx126,
mx103,
mx22,
mx166,
mx183,
mx136,
mx177,
mx141,
mx30,
mx73,
mx93,
mx38,
mx160,
mx55,
mx76,
mx32,
mx85,
mx6,
mx20,
mx59,
mx0,
mx43,
mx15,
mx88,
mx79,
mx180,
mx90,
mx67,
mx81,
mx138,
mx154,
mx53,
mx157,
mx45,
mx162,
mx51,
mx47,
mx195,
mx110,
mx8,
mx143,
mx147,
mx134,
mx1,
mx130,
mx115,
mx105,
mx40,
mx98,
mx36,
mx185,
mx192);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx201 = m.add_instruction(pass_op{}, mx195, mx197);
m.add_instruction(pass_op{},
mx201,
mx115,
mx105,
mx98,
mx123,
mx27,
mx126,
mx103,
mx64,
mx183,
mx174,
mx136,
mx177,
mx141,
mx51,
mx93,
mx113,
mx38,
mx160,
mx55,
mx30,
mx61,
mx138,
mx53,
mx76,
mx85,
mx6,
mx20,
mx59,
mx0,
mx40,
mx43,
mx197,
mx88,
mx79,
mx180,
mx90,
mx187,
mx81,
mx128,
mx157,
mx45,
mx162,
mx134,
mx1,
mx130,
mx147,
mx166,
mx121,
mx18,
mx169,
mx143,
mx119,
mx36,
mx185,
mx192);
auto mx203 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx203,
mx115,
mx105,
mx98,
mx123,
mx27,
mx126,
mx103,
mx64,
mx183,
mx174,
mx136,
mx177,
mx141,
mx51,
mx93,
mx113,
mx38,
mx160,
mx55,
mx30,
mx61,
mx138,
mx53,
mx76,
mx85,
mx6,
mx20,
mx59,
mx0,
mx40,
mx43,
mx197,
mx88,
mx79,
mx180,
mx90,
mx187,
mx81,
mx128,
mx157,
mx45,
mx162,
mx134,
mx1,
mx130,
mx147,
mx166,
mx121,
mx18,
mx169,
mx143,
mx119,
mx36,
mx185,
mx192);
auto mx205 = m.add_instruction(pass_op{}, mx203, mx201);
m.add_instruction(pass_op{},
mx205,
mx115,
mx105,
mx98,
mx123,
mx27,
mx126,
mx103,
mx64,
mx183,
mx174,
mx136,
mx177,
mx141,
mx51,
mx93,
mx113,
mx38,
mx160,
mx55,
mx30,
mx61,
mx138,
mx53,
mx76,
mx85,
mx6,
mx20,
mx59,
mx0,
mx40,
mx43,
mx197,
mx88,
mx79,
mx180,
mx90,
mx187,
mx81,
mx128,
mx157,
mx45,
mx162,
mx134,
mx1,
mx130,
mx147,
mx166,
mx121,
mx18,
mx169,
mx143,
mx119,
mx36,
mx185,
mx192);
m.add_instruction(pass_op{});
auto mx208 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx208,
mx30,
mx40,
mx64,
mx93,
mx18,
mx98,
mx115,
mx143,
mx38,
mx147,
mx183,
mx197,
mx150,
mx119,
mx32,
mx8,
mx105,
mx101,
mx110,
mx195,
mx47,
mx27,
mx22,
mx205,
mx121,
mx67,
mx187,
mx113,
mx73,
mx201,
mx130,
mx203,
mx169,
mx69,
mx15,
mx154,
mx61,
mx174,
mx123,
mx36,
mx192);
auto mx210 = m.add_instruction(pass_op{}, mx180);
m.add_instruction(pass_op{},
mx210,
mx143,
mx115,
mx18,
mx93,
mx150,
mx47,
mx187,
mx15,
mx169,
mx69,
mx205,
mx32,
mx119,
mx113,
mx73,
mx201,
mx30,
mx67,
mx121,
mx22,
mx27,
mx40,
mx98,
mx174,
mx61,
mx154,
mx64,
mx147,
mx38,
mx203,
mx130,
mx8,
mx110,
mx105,
mx101,
mx195,
mx183,
mx197,
mx123,
mx36,
mx192);
auto mx212 = m.add_instruction(pass_op{}, mx174);
m.add_instruction(pass_op{},
mx212,
mx32,
mx67,
mx90,
mx15,
mx138,
mx126,
mx103,
mx38,
mx136,
mx180,
mx141,
mx51,
mx30,
mx22,
mx201,
mx59,
mx134,
mx154,
mx150,
mx1,
mx160,
mx45,
mx6,
mx76,
mx88,
mx53,
mx47,
mx183,
mx81,
mx157,
mx93,
mx79,
mx85,
mx0,
mx210,
mx73,
mx8,
mx110,
mx20,
mx69,
mx177,
mx36,
mx143,
mx162,
mx147,
mx130,
mx115,
mx55,
mx105,
mx40,
mx98,
mx208,
mx203,
mx128,
mx205,
mx195,
mx101,
mx185,
mx43,
mx166,
mx192);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx216 = m.add_instruction(pass_op{}, mx208, mx212, mx210);
m.add_instruction(pass_op{},
mx216,
mx121,
mx30,
mx64,
mx93,
mx123,
mx143,
mx119,
mx36,
mx150,
mx8,
mx101,
mx169,
mx147,
mx110,
mx27,
mx61,
mx40,
mx205,
mx115,
mx32,
mx69,
mx67,
mx98,
mx187,
mx195,
mx73,
mx105,
mx183,
mx197,
mx22,
mx113,
mx201,
mx47,
mx130,
mx154,
mx15,
mx212,
mx18,
mx174,
mx38,
mx203,
mx192);
auto mx218 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx218,
mx121,
mx30,
mx64,
mx93,
mx123,
mx143,
mx119,
mx36,
mx150,
mx8,
mx101,
mx169,
mx147,
mx110,
mx27,
mx61,
mx40,
mx205,
mx115,
mx32,
mx69,
mx67,
mx98,
mx187,
mx195,
mx73,
mx105,
mx183,
mx197,
mx22,
mx113,
mx201,
mx47,
mx130,
mx154,
mx15,
mx212,
mx18,
mx174,
mx38,
mx203,
mx192);
auto mx220 = m.add_instruction(pass_op{}, mx218, mx216, mx166);
m.add_instruction(pass_op{},
mx220,
mx121,
mx30,
mx64,
mx93,
mx123,
mx143,
mx119,
mx36,
mx150,
mx8,
mx101,
mx169,
mx147,
mx110,
mx27,
mx61,
mx40,
mx205,
mx115,
mx32,
mx69,
mx67,
mx98,
mx187,
mx195,
mx73,
mx105,
mx183,
mx197,
mx22,
mx113,
mx201,
mx47,
mx130,
mx154,
mx15,
mx212,
mx18,
mx174,
mx38,
mx203,
mx192);
m.add_instruction(pass_op{});
auto mx223 = m.add_instruction(pass_op{}, mx205, mx220, mx57, mx205);
m.add_instruction(pass_op{},
mx223,
mx38,
mx192,
mx203,
mx130,
mx47,
mx143,
mx123,
mx169,
mx121,
mx147,
mx110,
mx27,
mx36,
mx150,
mx119,
mx101,
mx8,
mx64,
mx61,
mx115,
mx32,
mx69,
mx67,
mx98,
mx187,
mx195,
mx73,
mx105,
mx183,
mx197,
mx22,
mx113,
mx201,
mx174,
mx18,
mx93,
mx205,
mx40,
mx30,
mx154,
mx15,
mx212);
auto mx225 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx225,
mx45,
mx59,
mx76,
mx90,
mx218,
mx67,
mx126,
mx103,
mx136,
mx138,
mx15,
mx32,
mx1,
mx160,
mx150,
mx110,
mx51,
mx30,
mx6,
mx157,
mx93,
mx79,
mx85,
mx88,
mx53,
mx154,
mx134,
mx141,
mx180,
mx38,
mx81,
mx223,
mx183,
mx220,
mx210,
mx0,
mx208,
mx20,
mx69,
mx73,
mx185,
mx101,
mx201,
mx22,
mx203,
mx47,
mx128,
mx205,
mx195,
mx8,
mx177,
mx36,
mx55,
mx216,
mx105,
mx115,
mx130,
mx40,
mx98,
mx43,
mx166,
mx192,
mx162,
mx147,
mx143);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx229 = m.add_instruction(pass_op{}, mx225, mx69, mx192);
m.add_instruction(pass_op{},
mx229,
mx45,
mx59,
mx76,
mx90,
mx218,
mx67,
mx126,
mx103,
mx136,
mx138,
mx15,
mx32,
mx1,
mx160,
mx150,
mx110,
mx51,
mx30,
mx6,
mx157,
mx93,
mx79,
mx85,
mx88,
mx53,
mx154,
mx134,
mx141,
mx180,
mx38,
mx81,
mx223,
mx183,
mx220,
mx210,
mx0,
mx208,
mx20,
mx69,
mx73,
mx185,
mx101,
mx201,
mx22,
mx203,
mx47,
mx128,
mx205,
mx195,
mx8,
mx177,
mx36,
mx55,
mx216,
mx105,
mx115,
mx130,
mx40,
mx98,
mx43,
mx166,
mx192,
mx162,
mx147,
mx143);
m.add_instruction(pass_op{});
auto mx232 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx232,
mx160,
mx154,
mx76,
mx43,
mx67,
mx55,
mx187,
mx88,
mx126,
mx197,
mx225,
mx136,
mx59,
mx64,
mx15,
mx212,
mx128,
mx32,
mx218,
mx150,
mx216,
mx110,
mx169,
mx103,
mx113,
mx141,
mx79,
mx223,
mx90,
mx6,
mx18,
mx138,
mx210,
mx85,
mx53,
mx61,
mx45,
mx134,
mx119,
mx180,
mx166,
mx20,
mx0,
mx177,
mx81,
mx208,
mx157,
mx185,
mx1,
mx69,
mx201,
mx174,
mx101,
mx51,
mx22,
mx162,
mx220,
mx203,
mx47,
mx195,
mx73,
mx27,
mx205,
mx229,
mx8,
mx123,
mx121);
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx236 = m.add_instruction(pass_op{}, mx232, mx192, mx166);
m.add_instruction(pass_op{},
mx236,
mx160,
mx154,
mx76,
mx43,
mx67,
mx55,
mx187,
mx88,
mx126,
mx197,
mx225,
mx136,
mx59,
mx64,
mx15,
mx212,
mx128,
mx32,
mx218,
mx150,
mx216,
mx110,
mx169,
mx103,
mx113,
mx141,
mx79,
mx223,
mx90,
mx6,
mx18,
mx138,
mx210,
mx85,
mx53,
mx61,
mx45,
mx134,
mx119,
mx180,
mx166,
mx20,
mx0,
mx177,
mx81,
mx208,
mx157,
mx185,
mx1,
mx69,
mx201,
mx174,
mx101,
mx51,
mx22,
mx162,
mx220,
mx203,
mx47,
mx195,
mx73,
mx27,
mx205,
mx229,
mx8,
mx123,
mx121);
m.add_instruction(pass_op{});
auto mx239 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{},
mx239,
mx38,
mx192,
mx232,
mx203,
mx229,
mx183,
mx154,
mx201,
mx113,
mx174,
mx110,
mx197,
mx36,
mx115,
mx150,
mx98,
mx130,
mx32,
mx101,
mx169,
mx8,
mx64,
mx27,
mx225,
mx22,
mx147,
mx67,
mx205,
mx73,
mx61,
mx105,
mx18,
mx47,
mx123,
mx93,
mx195,
mx119,
mx69,
mx40,
mx187,
mx30,
mx15,
mx143,
mx236,
mx121,
mx212);
m.add_instruction(pass_op{});
auto mx242 = m.add_instruction(pass_op{}, mx239, mx223);
m.add_instruction(pass_op{},
mx242,
mx38,
mx192,
mx232,
mx203,
mx229,
mx183,
mx154,
mx201,
mx113,
mx174,
mx110,
mx197,
mx36,
mx115,
mx150,
mx98,
mx130,
mx32,
mx101,
mx169,
mx8,
mx64,
mx27,
mx225,
mx22,
mx147,
mx67,
mx205,
mx73,
mx61,
mx105,
mx18,
mx47,
mx123,
mx93,
mx195,
mx119,
mx69,
mx40,
mx187,
mx30,
mx15,
mx143,
mx236,
mx121,
mx212);
auto mx244 = add_alloc(m, migraphx::shape{migraphx::shape::float_type, {2, 5}});
m.add_instruction(pass_op{});
m.add_instruction(pass_op{});
auto mx247 = m.add_instruction(pass_op{}, mx244, mx229, mx242, mx236);
auto moutput =
m.add_parameter("output", migraphx::shape{migraphx::shape::float_type, {3, 1, 2, 5}});
auto mx248 = m.add_instruction(pass_op{}, mx247);
auto mx249 = m.add_instruction(pass_op{}, mx166);
auto mx250 = m.add_instruction(pass_op{}, mx85);
m.add_instruction(pass_op{}, moutput, mx250, mx249, mx248);
run_pass(m);
CHECK(m.get_parameter_shape("scratch").bytes() == 1600);
CHECK(no_allocate(m));
CHECK(is_disjoint({mx0, mx8}));
CHECK(is_disjoint({mx0, mx8}));
CHECK(is_disjoint({mx0, mx18, mx8}));
CHECK(is_disjoint({mx0, mx18, mx8}));
CHECK(is_disjoint({mx0, mx18, mx8}));
CHECK(is_disjoint({mx0, mx18, mx8}));
CHECK(is_disjoint({mx0, mx18, mx8}));
CHECK(is_disjoint({mx0, mx18, mx30, mx8}));
CHECK(is_disjoint({mx0, mx18, mx30, mx8}));
CHECK(is_disjoint({mx30, mx8}));
CHECK(is_disjoint({mx0, mx18, mx30, mx8}));
CHECK(is_disjoint({mx0, mx18, mx38, mx8}));
CHECK(is_disjoint({mx30, mx38}));
CHECK(is_disjoint({mx0, mx18, mx38, mx8}));
CHECK(is_disjoint({mx18, mx30, mx38, mx43, mx8}));
CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx8}));
CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx8}));
CHECK(is_disjoint({mx0, mx43, mx8}));
CHECK(is_disjoint({mx18, mx30, mx38, mx43, mx8}));
CHECK(is_disjoint({mx18, mx30, mx38, mx53, mx8}));
CHECK(is_disjoint({mx43, mx53}));
CHECK(is_disjoint({mx18, mx30, mx38, mx53, mx8}));
CHECK(is_disjoint({mx38, mx53}));
CHECK(is_disjoint({mx18, mx30, mx38, mx8}));
CHECK(is_disjoint({mx0, mx30, mx38, mx43, mx53, mx61, mx8}));
CHECK(is_disjoint({mx18, mx61}));
CHECK(is_disjoint({mx0, mx30, mx38, mx43, mx53, mx61, mx8}));
CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx53, mx61, mx67}));
CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx53, mx61}));
CHECK(is_disjoint({mx18, mx67}));
CHECK(is_disjoint({mx0, mx18, mx30, mx38, mx43, mx53, mx61, mx67}));
CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx76, mx8}));
CHECK(is_disjoint({mx38, mx76}));
CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx76, mx8}));
CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx81}));
CHECK(is_disjoint({mx61, mx67, mx76, mx81}));
CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx81}));
CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx88}));
CHECK(is_disjoint({mx81, mx88}));
CHECK(is_disjoint({mx18, mx30, mx38, mx61, mx67, mx8, mx88}));
CHECK(is_disjoint({mx0, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx0, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx0, mx101, mx18, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93}));
CHECK(is_disjoint({mx101, mx18, mx30, mx38, mx61, mx67, mx8, mx88, mx93}));
CHECK(
is_disjoint({mx0, mx101, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx101, mx88, mx93}));
CHECK(is_disjoint({mx0, mx101, mx18, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93}));
CHECK(is_disjoint(
{mx0, mx101, mx113, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint(
{mx0, mx101, mx113, mx18, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx113, mx93}));
CHECK(is_disjoint(
{mx0, mx101, mx113, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint(
{mx0, mx101, mx121, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx113, mx121}));
CHECK(is_disjoint(
{mx0, mx101, mx121, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx101, mx113, mx121, mx126, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx101, mx113, mx121, mx18, mx30, mx38, mx61, mx67, mx8, mx88, mx93}));
CHECK(is_disjoint({mx0,
mx101,
mx113,
mx121,
mx126,
mx18,
mx38,
mx43,
mx53,
mx61,
mx67,
mx76,
mx8,
mx81,
mx88,
mx93}));
CHECK(is_disjoint({mx126, mx88, mx93}));
CHECK(is_disjoint({mx101, mx113, mx121, mx126, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx101, mx113, mx121, mx136, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx126, mx136, mx81}));
CHECK(is_disjoint({mx101, mx113, mx121, mx136, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx121, mx136}));
CHECK(is_disjoint({mx101, mx113, mx121, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx0,
mx101,
mx113,
mx121,
mx126,
mx136,
mx143,
mx18,
mx38,
mx43,
mx53,
mx61,
mx67,
mx76,
mx8,
mx81,
mx88}));
CHECK(is_disjoint({mx101, mx143}));
CHECK(is_disjoint({mx0,
mx101,
mx113,
mx121,
mx126,
mx136,
mx143,
mx18,
mx38,
mx43,
mx53,
mx61,
mx67,
mx76,
mx8,
mx81,
mx88}));
CHECK(is_disjoint({mx0,
mx113,
mx121,
mx126,
mx136,
mx143,
mx150,
mx18,
mx30,
mx38,
mx43,
mx53,
mx61,
mx76,
mx81,
mx88,
mx93}));
CHECK(is_disjoint({mx101, mx150, mx81}));
CHECK(is_disjoint({mx0,
mx113,
mx121,
mx126,
mx136,
mx143,
mx150,
mx18,
mx30,
mx38,
mx43,
mx53,
mx61,
mx76,
mx81,
mx88,
mx93}));
CHECK(is_disjoint(
{mx101, mx113, mx121, mx143, mx150, mx157, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx121, mx157}));
CHECK(is_disjoint(
{mx101, mx113, mx121, mx143, mx150, mx157, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint(
{mx101, mx113, mx121, mx143, mx150, mx162, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx143, mx150, mx157, mx162}));
CHECK(is_disjoint(
{mx101, mx113, mx121, mx143, mx150, mx162, mx18, mx30, mx38, mx61, mx67, mx8, mx93}));
CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, mx169,
mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, mx169,
mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx177,
mx18,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx162, mx177}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx177,
mx18,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162, mx169, mx177,
mx18, mx183, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx177,
mx18,
mx183,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(
is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162, mx169, mx177,
mx183, mx30, mx38, mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx169, mx177, mx183}));
CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162, mx169, mx177,
mx18, mx183, mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88}));
CHECK(
is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18,
mx183, mx195, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93}));
CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157,
mx162, mx169, mx177, mx183, mx195, mx30, mx38, mx43,
mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx169, mx195}));
CHECK(
is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18,
mx183, mx195, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93}));
CHECK(
is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18,
mx183, mx203, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93}));
CHECK(is_disjoint({mx195, mx203}));
CHECK(
is_disjoint({mx0, mx113, mx121, mx126, mx136, mx143, mx157, mx162, mx169, mx177, mx18,
mx183, mx203, mx30, mx38, mx43, mx53, mx61, mx76, mx81, mx88, mx93}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx208,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx177,
mx18,
mx183,
mx195,
mx203,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162,
mx169, mx177, mx183, mx195, mx203, mx208, mx30, mx38, mx43,
mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx169, mx177, mx208}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx208,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx218,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx162, mx208, mx218}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx218,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx203, mx218}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162,
mx177, mx183, mx195, mx203, mx208, mx218, mx225, mx30, mx38,
mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx183, mx225}));
CHECK(is_disjoint({mx0, mx101, mx121, mx126, mx136, mx143, mx150, mx157, mx162,
mx177, mx183, mx195, mx203, mx208, mx218, mx225, mx30, mx38,
mx43, mx53, mx67, mx76, mx8, mx81, mx88, mx93}));
CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162,
mx169, mx177, mx18, mx195, mx203, mx208, mx218, mx225, mx232,
mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88}));
CHECK(is_disjoint({mx162, mx183, mx232}));
CHECK(is_disjoint({mx0, mx101, mx113, mx121, mx126, mx136, mx150, mx157, mx162,
mx169, mx177, mx18, mx195, mx203, mx208, mx218, mx225, mx232,
mx38, mx43, mx53, mx61, mx67, mx76, mx8, mx81, mx88}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx225,
mx232,
mx239,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx203, mx239}));
CHECK(is_disjoint({mx101,
mx113,
mx121,
mx143,
mx150,
mx169,
mx18,
mx183,
mx195,
mx203,
mx225,
mx232,
mx239,
mx30,
mx38,
mx61,
mx67,
mx8,
mx93}));
CHECK(is_disjoint({mx225, mx232, mx239, mx244}));
CHECK(is_disjoint({mx162, mx244, mx81}));
}
TEST_CASE(literal_test)
{
migraphx::program p;
......
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp>
#include <sstream>
......@@ -24,33 +25,102 @@ migraphx::program create_program()
return p;
}
TEST_CASE(module_ins_clear)
TEST_CASE(calc_implict_deps)
{
migraphx::program p1 = create_program();
migraphx::program p2;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
p2 = p1;
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_parameter("cond", cond_s);
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
EXPECT(p1 == p2);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(ys, datay));
auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
else_mod->add_return({a3, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
auto implicit_deps = mm->calc_implicit_deps();
EXPECT(migraphx::contains(implicit_deps, ret));
EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
}
TEST_CASE(module_print_graph)
TEST_CASE(module_annotate)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
EXPECT(*mm1 == *mm2);
std::stringstream ss1;
mm1->print_graph(ss1, true);
mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
std::stringstream ss2;
mm2->print_graph(ss2, true);
mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_ins_clear)
{
migraphx::program p1 = create_program();
migraphx::program p2;
p2 = p1;
EXPECT(p1 == p2);
}
TEST_CASE(module_name)
{
migraphx::module m1("name");
EXPECT(m1.name() == "name");
auto m2 = m1; // NOLINT
EXPECT(m2.name() == "name");
migraphx::module m3;
m3 = m1;
EXPECT(m3.name() == "name");
}
TEST_CASE(module_name_main)
{
migraphx::program p;
auto* mm = p.get_main_module();
EXPECT(mm->name() == "main");
}
TEST_CASE(module_print_cpp)
{
migraphx::program p1 = create_program();
......@@ -68,43 +138,23 @@ TEST_CASE(module_print_cpp)
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_annotate)
TEST_CASE(module_print_graph)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
EXPECT(*mm1 == *mm2);
std::stringstream ss1;
mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
mm1->print_graph(ss1, true);
std::stringstream ss2;
mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
mm2->print_graph(ss2, true);
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_name)
{
migraphx::module m1("name");
EXPECT(m1.name() == "name");
auto m2 = m1; // NOLINT
EXPECT(m2.name() == "name");
migraphx::module m3;
m3 = m1;
EXPECT(m3.name() == "name");
}
TEST_CASE(module_name_main)
{
migraphx::program p;
auto* mm = p.get_main_module();
EXPECT(mm->name() == "main");
}
TEST_CASE(program_module_assign)
{
migraphx::program p;
......@@ -204,51 +254,62 @@ TEST_CASE(submodule_copy)
EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
}
TEST_CASE(calc_implict_deps)
TEST_CASE(parameter_name_order)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
migraphx::shape ys{migraphx::shape::float_type, {3, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
std::vector<float> datay = {8, 7, 6, 5, 4, 3, 2, 1, 0};
auto lx = mm->add_literal(migraphx::literal(xs, datax));
auto ly = mm->add_literal(migraphx::literal(ys, datay));
auto cond = mm->add_parameter("cond", cond_s);
auto x1 = mm->add_parameter("x1", xs);
auto x2 = mm->add_parameter("x2", xs);
auto y2 = mm->add_parameter("y2", ys);
auto* then_mod = p.create_module("If_5_if");
auto l1 = then_mod->add_literal(migraphx::literal(ys, datay));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x1, lx);
then_mod->add_return({a1, l1});
auto* then_mod1 = p.create_module("If_6_if");
auto l11 = then_mod1->add_literal(migraphx::literal(ys, datay));
auto a11 = then_mod1->add_instruction(migraphx::make_op("add"), x2, lx);
then_mod1->add_return({a11, l11});
auto* else_mod1 = p.create_module("If_6_else");
auto l21 = else_mod1->add_literal(migraphx::literal(xs, datax));
auto a21 = else_mod1->add_instruction(migraphx::make_op("mul"), y2, ly);
else_mod1->add_return({l21, a21});
auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(ys, datay));
auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
else_mod->add_return({a2, l2});
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module mm("main");
auto x1 = mm.add_parameter("x1", s);
auto x2 = mm.add_parameter("x2", s);
auto x3 = mm.add_parameter("x3", s);
auto x4 = mm.add_parameter("x4", s);
std::vector<std::string> param_names = {"x1", "x2", "x3", "x4"};
auto sum1 = mm.add_instruction(migraphx::make_op("add"), x1, x2);
auto sum2 = mm.add_instruction(migraphx::make_op("add"), x3, x4);
auto r = mm.add_instruction(migraphx::make_op("mul"), sum1, sum2);
mm.add_return({r});
auto names = mm.get_parameter_names();
EXPECT(param_names == names);
auto m1 = mm;
auto names1 = m1.get_parameter_names();
EXPECT(param_names == names1);
}
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
mm->add_return({ret});
struct check_for_pass_op
{
bool* found = nullptr;
std::string name() const { return "check_for_pass_op"; }
void apply(migraphx::module& m) const
{
*found |= std::any_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "pass"; });
}
};
TEST_CASE(module_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->set_bypass();
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(not found);
}
auto implicit_deps = mm->calc_implicit_deps();
EXPECT(migraphx::contains(implicit_deps, ret));
EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
TEST_CASE(module_without_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(found);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -75,6 +75,26 @@ TEST_CASE(gather_test_1)
EXPECT(m1 == m2);
}
migraphx::module create_padded_op(const std::vector<size_t>& pad_vals)
{
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = m.add_parameter("data", s);
auto r = m.add_instruction(migraphx::make_op("pooling", {{"padding", pad_vals}}), si);
m.add_return({r});
return m;
}
TEST_CASE(padding_attr_test)
{
migraphx::module m1 = create_padded_op({0, 1});
migraphx::module m2 = create_padded_op({0, 1, 0, 1});
run_pass(m1);
EXPECT(m1 == m2);
}
migraphx::module create_reduce_mean(const std::vector<int64_t>& axes)
{
migraphx::module m;
......
depthtospace_crd_test:
6
xy" DepthToSpace*
blocksize*
mode"CRDdepthtospace_crd_testZ
x




b
y





B
\ No newline at end of file
depthtospace_simple_test:
6
xy" DepthToSpace*
blocksize*
mode"DCRdepthtospace_simple_testZ
x




b
y




B
\ No newline at end of file
depthtospace_test:
6
xy" DepthToSpace*
blocksize*
mode"DCRdepthtospace_testZ
x




b
y





B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment