Commit a88074c2 authored by Umang Yadav's avatar Umang Yadav
Browse files

Nested if_then_else and other tests

parent 454d9c08
...@@ -44,15 +44,173 @@ ...@@ -44,15 +44,173 @@
#include <migraphx/target_assignments.hpp> #include <migraphx/target_assignments.hpp>
#include <test.hpp> #include <test.hpp>
TEST_CASE(simple_no_branch_test)
{
/*
Add (tid = 1)
|
Mul (tid = 0)
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto cpu_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto gpu_ins = mm->add_instruction(migraphx::make_op("mul"), cpu_ins, z_param);
mm->add_return({gpu_ins});
tass.insert(tass.begin(), std::make_pair(cpu_ins, 1));
tass.insert(tass.begin(), std::make_pair(gpu_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(if_then_else_program)
{
/*
If -----------------> Return
|
---------------
| |
(then_mod) (else_mod)
| |
Add (tid = 0) Mul (tid = 1)
*/
migraphx::target_assignments tass;
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
std::vector<float> data1(ds.elements(), 1);
std::vector<float> data2(ds.elements(), 2);
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto cond = mm->add_parameter("cond", cond_s);
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto* then_mod = p1.create_module("if_gpu_mod");
auto l1 = then_mod->add_literal(migraphx::literal(ds, data1));
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
then_mod->add_return({a1});
auto* else_mod = p1.create_module("else_cpu_mod");
auto l2 = else_mod->add_literal(migraphx::literal(ds, data2));
auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({a2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
tass.insert(tass.begin(), std::make_pair(l1, 0));
tass.insert(tass.begin(), std::make_pair(a1, 0));
tass.insert(tass.begin(), std::make_pair(l2, 1));
tass.insert(tass.begin(), std::make_pair(a2, 1));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto cond = mm->add_parameter("cond", cond_s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
migraphx::module_ref if_gpu_mod = p2.create_module("if_gpu_mod");
auto x_if_gpu_mod_0 = if_gpu_mod->add_literal(migraphx::literal(ds, data1));
auto x_if_gpu_mod_1 =
if_gpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{x_if_gpu_mod_0, x},
{target_mod_0_0});
auto x_if_gpu_mod_2 = if_gpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_if_gpu_mod_1);
if_gpu_mod->add_return({x_if_gpu_mod_2});
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("mul"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref else_cpu_mod = p2.create_module("else_cpu_mod");
auto x_else_cpu_mod_0 = else_cpu_mod->add_literal(migraphx::literal(ds, data2));
auto x_else_cpu_mod_1 =
else_cpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{x_else_cpu_mod_0, y},
{target_mod_1_0});
auto x_else_cpu_mod_2 = else_cpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_cpu_mod_1);
else_cpu_mod->add_return({x_else_cpu_mod_2});
auto x_3 = mm->add_instruction(migraphx::make_op("if"), {cond}, {if_gpu_mod, else_cpu_mod});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
mm->add_return({x_4});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(fork_case) TEST_CASE(fork_case)
{ {
/* /*
Add (tid = 0) Add (tid = 0)
| |
---------------- ---------------
| | | |
Mul Identity Mul Identity
(tid = 0) (tid = 1) (tid = 0) (tid = 1)
| |
Return Return
*/ */
auto s = migraphx::shape{migraphx::shape::float_type, {8}}; auto s = migraphx::shape{migraphx::shape::float_type, {8}};
...@@ -132,6 +290,8 @@ TEST_CASE(merge_case) ...@@ -132,6 +290,8 @@ TEST_CASE(merge_case)
----------------- -----------------
| |
Mul (tid = 0) Mul (tid = 0)
|
Return
*/ */
migraphx::target_assignments tass; migraphx::target_assignments tass;
...@@ -212,6 +372,8 @@ TEST_CASE(fork_and_merge_case) ...@@ -212,6 +372,8 @@ TEST_CASE(fork_and_merge_case)
---------------- ----------------
| |
Sub (tid = 0) Sub (tid = 0)
|
Return
*/ */
auto s = migraphx::shape{migraphx::shape::float_type, {8}}; auto s = migraphx::shape{migraphx::shape::float_type, {8}};
...@@ -295,4 +457,213 @@ TEST_CASE(fork_and_merge_case) ...@@ -295,4 +457,213 @@ TEST_CASE(fork_and_merge_case)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
TEST_CASE(nested_if_then_else_program)
{
/*
If ----------------> Return
|
-----------------------------------------
| |
(then_mod) (else_mod)
| |
Add (tid = 3) Mul (tid = 2)
| |
If If
| |
---------------------- --------------------
| | | |
(then_mod) (else_mod) (then_mod) (else_mod)
| | | |
Add (tid = 1) Add (tid = 0) Add (tid = 0) Add (tid = 1)
*/
migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
migraphx::shape cond_s{migraphx::shape::bool_type};
migraphx::target_assignments tass;
std::vector<float> data(ds.elements(), -1);
migraphx::program p1;
{
std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}};
auto* mm = p1.get_main_module();
auto cond_0 = mm->add_parameter("cond_0", cond_s);
auto cond_1 = mm->add_parameter("cond_1", cond_s);
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds);
auto create_test_module = [&](migraphx::program& prog,
std::size_t tid,
std::string param_prefix) {
std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name);
auto l1 = test_mod->add_literal(migraphx::literal(ds, data));
auto test_mod_param_0 = test_mod->add_parameter(param_prefix + "_param_0", ds);
auto ins1 = test_mod->add_instruction(migraphx::make_op("add"), test_mod_param_0, l1);
test_mod->add_return({ins1});
tass.insert(tass.begin(), std::make_pair(ins1, tid));
return test_mod;
};
auto* then_mod = p1.create_module("then_mod");
auto then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto then_mod_param_0 = then_mod->add_parameter("then_mod_param_0", ds);
auto then_mod_param_1 = then_mod->add_parameter("then_mod_param_1", ds);
auto then_mod_add_ins =
then_mod->add_instruction(migraphx::make_op("add"), then_mod_param_0, then_mod_param_1);
tass.insert(tass.begin(), std::make_pair(then_mod_add_ins, 3));
auto then_mod_if = then_mod->add_instruction(
migraphx::make_op("if"),
{then_mod_cond, then_mod_param_0, then_mod_add_ins},
{create_test_module(p1, 1, "1_"), create_test_module(p1, 0, "2_")});
auto then_mod_if_0 = then_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if);
then_mod->add_return({then_mod_if_0});
auto* else_mod = p1.create_module("else_mod");
auto else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto else_mod_param_0 = else_mod->add_parameter("else_mod_param_0", ds);
auto else_mod_param_1 = else_mod->add_parameter("else_mod_param_1", ds);
auto else_mod_add_ins =
else_mod->add_instruction(migraphx::make_op("mul"), else_mod_param_0, else_mod_param_1);
tass.insert(tass.begin(), std::make_pair(else_mod_add_ins, 2));
auto else_mod_if = else_mod->add_instruction(
migraphx::make_op("if"),
{else_mod_cond, else_mod_add_ins, else_mod_param_0},
{create_test_module(p1, 0, "1_"), create_test_module(p1, 1, "2_")});
auto else_mod_if_0 = else_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if);
else_mod->add_return({else_mod_if_0});
// Create nested and multi-target main module using "If"
auto main_if_ins = mm->add_instruction(
migraphx::make_op("if"), {cond_0, cond_1, x, y, cond_1, x, z}, {then_mod, else_mod});
auto r =
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), main_if_ins);
mm->add_return({r});
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}};
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", ds);
auto y = mm->add_parameter("y", ds);
auto x = mm->add_parameter("x", ds);
auto cond_1 = mm->add_parameter("cond_1", cond_s);
auto cond_0 = mm->add_parameter("cond_0", cond_s);
auto create_test_module = [&](migraphx::program& prog, std::size_t tid) {
std::string mod_name =
"target_mod_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name);
auto test_mod_param_0 = test_mod->add_parameter("param:0", ds);
auto test_mod_param_1 = test_mod->add_parameter("param:1", ds);
auto ins1 = test_mod->add_instruction(
migraphx::make_op("add"), test_mod_param_1, test_mod_param_0);
test_mod->add_return({ins1});
tass.insert(tass.begin(), std::make_pair(ins1, tid));
return test_mod;
};
migraphx::module_ref target_1_0 = p2.create_module("target_1_0");
auto target_1_0_1_param_0 = target_1_0->add_literal(ds, data);
auto target_1_0_1_param_1 = target_1_0->add_parameter("1__param_0", ds);
auto x_target_1_0_2 =
target_1_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{target_1_0_1_param_0, target_1_0_1_param_1},
{create_test_module(p2, 1)});
auto x_target_1_0_3 = target_1_0->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_target_1_0_2);
target_1_0->add_return({x_target_1_0_3});
migraphx::module_ref target_0_0 = p2.create_module("target_0_0");
auto target_0_0_2_param_0 = target_0_0->add_literal(ds, data);
auto target_0_0_2_param_1 = target_0_0->add_parameter("2__param_0", ds);
auto x_target_0_0_2 =
target_0_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{target_0_0_2_param_0, target_0_0_2_param_1},
{create_test_module(p2, 0)});
auto x_target_0_0_3 = target_0_0->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_target_0_0_2);
target_0_0->add_return({x_target_0_0_3});
migraphx::module_ref target_3_0 = p2.create_module("target_mod_3_0");
auto target_mod_3_0_param_1 = target_3_0->add_parameter("param:1", ds);
auto target_mod_3_0_param_0 = target_3_0->add_parameter("param:0", ds);
auto target_3_add_ins = target_3_0->add_instruction(
migraphx::make_op("add"), target_mod_3_0_param_1, target_mod_3_0_param_0);
target_3_0->add_return({target_3_add_ins});
migraphx::module_ref then_mod = p2.create_module("then_mod");
auto then_mod_then_mod_param_1 = then_mod->add_parameter("then_mod_param_1", ds);
auto then_mod_then_mod_param_0 = then_mod->add_parameter("then_mod_param_0", ds);
auto then_mod_then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto x_then_mod_3 =
then_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 3}}),
{then_mod_then_mod_param_1, then_mod_then_mod_param_0},
{target_3_0});
auto x_then_mod_4 = then_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_then_mod_3);
auto x_then_mod_5 = then_mod->add_instruction(
migraphx::make_op("if"),
{then_mod_then_mod_cond, then_mod_then_mod_param_0, x_then_mod_4},
{target_1_0, target_0_0});
auto x_then_mod_6 = then_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_then_mod_5);
then_mod->add_return({x_then_mod_6});
migraphx::module_ref target_0_1 = p2.create_module("target_0_1");
auto target_0_1_1_param_0 = target_0_1->add_literal(ds, data);
auto target_0_1_1_param_1 = target_0_1->add_parameter("1__param_0", ds);
auto x_target_0_1_2 =
target_0_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{target_0_1_1_param_0, target_0_1_1_param_1},
{create_test_module(p2, 0)});
auto x_target_0_1_3 = target_0_1->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_target_0_1_2);
target_0_1->add_return({x_target_0_1_3});
migraphx::module_ref target_1_1 = p2.create_module("target_1_1");
auto target_1_1_2_param_0 = target_1_1->add_literal(ds, data);
auto target_1_1_2_param_1 = target_1_1->add_parameter("2__param_0", ds);
auto x_target_1_1_2 =
target_1_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{target_1_1_2_param_0, target_1_1_2_param_1},
{create_test_module(p2, 1)});
auto x_target_1_1_3 = target_1_1->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_target_1_1_2);
target_1_1->add_return({x_target_1_1_3});
migraphx::module_ref target_2_0 = p2.create_module("target_mod_2_0");
auto target_mod_2_0_param_1 = target_2_0->add_parameter("param:1", ds);
auto target_mod_2_0_param_0 = target_2_0->add_parameter("param:0", ds);
auto target_2_mul_ins = target_2_0->add_instruction(
migraphx::make_op("mul"), target_mod_2_0_param_1, target_mod_2_0_param_0);
target_2_0->add_return({target_2_mul_ins});
migraphx::module_ref else_mod = p2.create_module("else_mod");
auto else_mod_else_mod_param_0 = else_mod->add_parameter("else_mod_param_0", ds);
auto else_mod_else_mod_param_1 = else_mod->add_parameter("else_mod_param_1", ds);
auto else_mod_else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto x_else_mod_3 =
else_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 2}}),
{else_mod_else_mod_param_1, else_mod_else_mod_param_0},
{target_2_0});
auto x_else_mod_4 = else_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_mod_3);
auto x_else_mod_5 = else_mod->add_instruction(
migraphx::make_op("if"),
{else_mod_else_mod_cond, x_else_mod_4, else_mod_else_mod_param_0},
{target_0_1, target_1_1});
auto x_else_mod_6 = else_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_mod_5);
else_mod->add_return({x_else_mod_6});
auto x_5 = mm->add_instruction(
migraphx::make_op("if"), {cond_0, cond_1, x, y, cond_1, x, z}, {then_mod, else_mod});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6});
}
EXPECT(p1.sort() == p2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment