"docs/source/TrainingService/Overview.rst" did not exist on "3ec26b40afd88b8255ebc74b46f475b77aa4d19b"
Commit 56c4736a authored by umang yadav's avatar umang yadav
Browse files

add select module to tests

parent af0bfffd
...@@ -301,8 +301,7 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -301,8 +301,7 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
auto x = mm->add_parameter("x", ds); auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds); auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds); auto z = mm->add_parameter("z", ds);
auto create_test_module = [&](migraphx::program& prog, auto create_test_module = [&](migraphx::program& prog, std::size_t tid) {
std::size_t tid) {
std::string mod_name = std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++); "target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name); auto* test_mod = prog.create_module(mod_name);
...@@ -329,17 +328,16 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -329,17 +328,16 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
auto then_mod_ref_ins = auto then_mod_ref_ins =
then_mod->add_instruction(migraphx::make_op("add"), then_mod_param_0, then_mod_param_1); then_mod->add_instruction(migraphx::make_op("add"), then_mod_param_0, then_mod_param_1);
tass.insert(tass.begin(), std::make_pair(then_mod_ref_ins, 3)); tass.insert(tass.begin(), std::make_pair(then_mod_ref_ins, 3));
auto then_mod_if = then_mod->add_instruction( auto then_mod_if =
migraphx::make_op("if"), then_mod->add_instruction(migraphx::make_op("if"),
{then_mod_cond, {then_mod_cond,
then_mod_param_0, then_mod_param_0,
then_mod_param_1, then_mod_param_1,
then_mod_param_2, then_mod_param_2,
then_mod_ref_ins, then_mod_ref_ins,
then_mod_param_1, then_mod_param_1,
then_mod_param_2}, then_mod_param_2},
{create_test_module(p, 1), {create_test_module(p, 1), create_test_module(p, 0)});
create_test_module(p, 0)});
auto then_mod_if_0 = auto then_mod_if_0 =
then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if); then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if);
then_mod->add_return({then_mod_if_0}); then_mod->add_return({then_mod_if_0});
...@@ -355,17 +353,16 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -355,17 +353,16 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
auto else_mod_fpga_ins = auto else_mod_fpga_ins =
else_mod->add_instruction(migraphx::make_op("add"), else_mod_param_0, else_mod_param_2); else_mod->add_instruction(migraphx::make_op("add"), else_mod_param_0, else_mod_param_2);
tass.insert(tass.begin(), std::make_pair(else_mod_fpga_ins, 2)); tass.insert(tass.begin(), std::make_pair(else_mod_fpga_ins, 2));
auto else_mod_if = else_mod->add_instruction( auto else_mod_if =
migraphx::make_op("if"), else_mod->add_instruction(migraphx::make_op("if"),
{else_mod_cond, {else_mod_cond,
else_mod_fpga_ins, else_mod_fpga_ins,
else_mod_param_0, else_mod_param_0,
else_mod_param_1, else_mod_param_1,
else_mod_param_2, else_mod_param_2,
else_mod_param_1, else_mod_param_1,
else_mod_param_0}, else_mod_param_0},
{create_test_module(p, 0), {create_test_module(p, 0), create_test_module(p, 1)});
create_test_module(p, 1)});
auto else_mod_if_0 = auto else_mod_if_0 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if); else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if);
else_mod->add_return({else_mod_if_0}); else_mod->add_return({else_mod_if_0});
...@@ -429,6 +426,7 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -429,6 +426,7 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
TEST_CASE(multitarget_select_module) TEST_CASE(multitarget_select_module)
{ {
migraphx::program p; migraphx::program p;
migraphx::target_assignments tass;
// create batch submodules // create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) { auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p.create_module(module_name); auto* submod = p.create_module(module_name);
...@@ -440,6 +438,9 @@ TEST_CASE(multitarget_select_module) ...@@ -440,6 +438,9 @@ TEST_CASE(multitarget_select_module)
submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input); submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins0 = submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit); auto add_ins0 = submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
auto add_ins1 = submod->add_instruction(migraphx::make_op("add"), add_ins0, broadcast_lit); auto add_ins1 = submod->add_instruction(migraphx::make_op("add"), add_ins0, broadcast_lit);
tass.insert(tass.begin(), std::make_pair(broadcast_lit, batch_size - 1));
tass.insert(tass.begin(), std::make_pair(add_ins0, batch_size - 1));
tass.insert(tass.begin(), std::make_pair(add_ins1, batch_size - 1));
submod->add_return({add_ins1}); submod->add_return({add_ins1});
return submod; return submod;
}; };
...@@ -448,42 +449,6 @@ TEST_CASE(multitarget_select_module) ...@@ -448,42 +449,6 @@ TEST_CASE(multitarget_select_module)
auto* batch3 = create_submodule(3, "batch_3"); auto* batch3 = create_submodule(3, "batch_3");
auto* batch4 = create_submodule(4, "batch_4"); auto* batch4 = create_submodule(4, "batch_4");
auto* run_cpu_mod = p.create_module("cpu_mod");
auto cpu_param =
run_cpu_mod->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {1, 4}});
auto run_cpu_ins = run_cpu_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {cpu_param}, {batch1});
auto run_cpu_ins_0 = run_cpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_cpu_ins);
run_cpu_mod->add_return({run_cpu_ins_0});
auto* run_gpu_mod = p.create_module("gpu_mod");
auto gpu_param =
run_gpu_mod->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {2, 4}});
auto run_gpu_ins = run_gpu_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {gpu_param}, {batch2});
auto run_gpu_ins_0 = run_gpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_gpu_ins);
run_gpu_mod->add_return({run_gpu_ins_0});
auto* run_fpga_mod = p.create_module("fpga_mod");
auto fpga_param =
run_fpga_mod->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto run_fpga_ins = run_fpga_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 2}}), {fpga_param}, {batch3});
auto run_fpga_ins_0 = run_fpga_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_fpga_ins);
run_fpga_mod->add_return({run_fpga_ins_0});
auto* run_ref_mod = p.create_module("ref_mod");
auto ref_param =
run_ref_mod->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {4, 4}});
auto run_ref_ins = run_ref_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 3}}), {ref_param}, {batch4});
auto run_ref_ins_0 = run_ref_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_ref_ins);
run_ref_mod->add_return({run_ref_ins_0});
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape dyn_s{migraphx::shape::float_type, {{1, 4}, {4, 4}}}; migraphx::shape dyn_s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input = mm->add_parameter("data", dyn_s); auto input = mm->add_parameter("data", dyn_s);
...@@ -494,12 +459,13 @@ TEST_CASE(multitarget_select_module) ...@@ -494,12 +459,13 @@ TEST_CASE(multitarget_select_module)
auto sm_ins = mm->add_instruction( auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}), migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input}, {input},
{run_cpu_mod, run_gpu_mod, run_fpga_mod, run_ref_mod}); {batch1, batch2, batch3, batch4});
auto ret0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins); auto ret0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm->add_return({ret0}); mm->add_return({ret0});
// compile // compile
migraphx::compile_options gpu_opts; migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true; gpu_opts.offload_copy = true;
migraphx::generate_root_modules(p, tass);
p.compile({migraphx::make_target("gpu"), p.compile({migraphx::make_target("gpu"),
migraphx::make_target("cpu"), migraphx::make_target("cpu"),
migraphx::make_target("ref"), migraphx::make_target("ref"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment