Commit 31b6db1a authored by umangyadav's avatar umangyadav
Browse files

if then else tests are working

parent b448194e
...@@ -21,9 +21,11 @@ ...@@ -21,9 +21,11 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/target_assignments.hpp"
#include <cstddef> #include <cstddef>
#include <limits> #include <limits>
#include <iterator> #include <iterator>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
...@@ -65,21 +67,20 @@ static literal get_scalar(instruction_ref ins) ...@@ -65,21 +67,20 @@ static literal get_scalar(instruction_ref ins)
} }
return r; return r;
} }
void partition(migraphx::module_ref mm,
void partition(migraphx::program& p, const target_assignments& tass) migraphx::program& p,
const target_assignments& tass,
std::unordered_map<std::size_t, std::size_t>& tid_counter)
{ {
auto* mm = p.get_main_module();
// sort the graph in reverse post order DFS order
mm->sort(); mm->sort();
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{ {
std::cout << "sorted program: \n"; std::cout << "sorted module: \n";
p.debug_print(); mm->debug_print();
} }
std::vector<instruction_ref> same_tid_ins_vec; std::vector<instruction_ref> same_tid_ins_vec;
std::unordered_set<instruction_ref> same_tid_ins_set; std::unordered_set<instruction_ref> same_tid_ins_set;
// walk the graph in reverse-DFS order // walk the graph in reverse-DFS order
std::unordered_map<std::size_t, std::size_t> tid_counter;
size_t current_tid = std::numeric_limits<std::size_t>::max(); size_t current_tid = std::numeric_limits<std::size_t>::max();
std::unordered_set<instruction_ref> skip_ins; std::unordered_set<instruction_ref> skip_ins;
for(auto ins : iterator_for(*mm)) for(auto ins : iterator_for(*mm))
...@@ -94,12 +95,29 @@ void partition(migraphx::program& p, const target_assignments& tass) ...@@ -94,12 +95,29 @@ void partition(migraphx::program& p, const target_assignments& tass)
ins->debug_print(); ins->debug_print();
std::cout << "\n"; std::cout << "\n";
} }
if(skip_ins.count(ins) == 0)
{
if(not ins->module_inputs().empty())
{
for(auto sub_mod : ins->module_inputs())
{
partition(sub_mod, p, tass, tid_counter);
}
mm->replace_instruction(
ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
}
}
if((starts_with(ins->name(), "@") and ins->name() != "@return") or skip_ins.count(ins) != 0) if((starts_with(ins->name(), "@") and ins->name() != "@return") or skip_ins.count(ins) != 0)
{ {
continue; continue;
} }
else if(ins->name() != "@return" and current_tid == std::numeric_limits<std::size_t>::max()) else if(ins->name() != "@return" and current_tid == std::numeric_limits<std::size_t>::max())
{ {
if(tass.find(ins) == tass.end())
{
continue;
}
current_tid = tass.at(ins); current_tid = tass.at(ins);
tid_counter[current_tid] = 0; tid_counter[current_tid] = 0;
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
...@@ -110,7 +128,7 @@ void partition(migraphx::program& p, const target_assignments& tass) ...@@ -110,7 +128,7 @@ void partition(migraphx::program& p, const target_assignments& tass)
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
} }
else else if(ins->name() == "@return" or tass.at(ins) != current_tid)
{ {
// gather all parameters // gather all parameters
std::unordered_set<instruction_ref> params; std::unordered_set<instruction_ref> params;
...@@ -152,7 +170,8 @@ void partition(migraphx::program& p, const target_assignments& tass) ...@@ -152,7 +170,8 @@ void partition(migraphx::program& p, const target_assignments& tass)
std::cout << "\n"; std::cout << "\n";
} }
auto* tmod = p.create_module("target_mod:" + std::to_string(tid_counter[current_tid])); auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid]));
std::unordered_map<instruction_ref, instruction_ref> params_map; std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0; std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params; std::vector<instruction_ref> rot_ins_params;
...@@ -180,7 +199,8 @@ void partition(migraphx::program& p, const target_assignments& tass) ...@@ -180,7 +199,8 @@ void partition(migraphx::program& p, const target_assignments& tass)
std::back_inserter(new_inputs), std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); }); [&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ? // [TODO]: what if it is has module args ?
auto tmod_tins = tmod->add_instruction((*tins)->get_operator(), new_inputs); auto tmod_tins = tmod->add_instruction(
(*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
// add parameter to params map so that it can be looked up by other insturctions // add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins; params_map[*tins] = tmod_tins;
} }
...@@ -232,11 +252,19 @@ void partition(migraphx::program& p, const target_assignments& tass) ...@@ -232,11 +252,19 @@ void partition(migraphx::program& p, const target_assignments& tass)
} }
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{ {
std::cout << "program after creation of tmod and rot: \n"; std::cout << "module after creation of tmod and rot: \n";
p.debug_print(); mm->debug_print();
} }
} }
} }
}
void partition(migraphx::program& p, const target_assignments& tass)
{
auto* mm = p.get_main_module();
// sort the graph in reverse post order DFS order
std::unordered_map<std::size_t, std::size_t> tid_counter;
partition(mm, p, tass, tid_counter);
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(p);
} }
......
...@@ -203,7 +203,7 @@ TEST_CASE(single_target_multi_compile) ...@@ -203,7 +203,7 @@ TEST_CASE(single_target_multi_compile)
auto score_threshold = mm->add_literal(0.0f); auto score_threshold = mm->add_literal(0.0f);
auto r = mm->add_instruction( auto r = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}), {{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_param, boxes_param,
scores_l, scores_l,
max_out_l, max_out_l,
...@@ -256,24 +256,9 @@ TEST_CASE(multitarget_compile_if_then_else_partition) ...@@ -256,24 +256,9 @@ TEST_CASE(multitarget_compile_if_then_else_partition)
auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, l2); auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({a2}); else_mod->add_return({a2});
// auto* run_on_cpu_mod = p.create_module("run_on_cpu");
// auto run_cpu_ins = run_on_cpu_mod->add_instruction(
// migraphx::make_op("run_on_target", {{"target_id", 1}}), {y}, {else_mod});
// auto run_cpu_ins_0 = run_on_cpu_mod->add_instruction(
// migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_cpu_ins);
// run_on_cpu_mod->add_return({run_cpu_ins_0});
// auto* run_on_gpu_mod = p.create_module("run_on_gpu");
// auto run_gpu_ins = run_on_gpu_mod->add_instruction(
// migraphx::make_op("run_on_target", {{"target_id", 0}}), {x}, {then_mod});
// auto run_gpu_ins_0 = run_on_gpu_mod->add_instruction(
// migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_gpu_ins);
// run_on_gpu_mod->add_return({run_gpu_ins_0});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
p.debug_print();
migraphx::target_assignments tass; migraphx::target_assignments tass;
tass.insert(tass.begin(), std::make_pair(l1, 0)); tass.insert(tass.begin(), std::make_pair(l1, 0));
tass.insert(tass.begin(), std::make_pair(a1, 0)); tass.insert(tass.begin(), std::make_pair(a1, 0));
...@@ -284,8 +269,13 @@ TEST_CASE(multitarget_compile_if_then_else_partition) ...@@ -284,8 +269,13 @@ TEST_CASE(multitarget_compile_if_then_else_partition)
// compile // compile
migraphx::compile_options gpu_opts; migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true; gpu_opts.offload_copy = true;
p.compile({migraphx::make_target("gpu"), migraphx::make_target("cpu")}, {gpu_opts}); p.compile(
EXPECT(check_compiled_program(p, {migraphx::make_target("gpu"), migraphx::make_target("cpu")})); {migraphx::make_target("gpu"), migraphx::make_target("cpu"), migraphx::make_target("ref")},
{gpu_opts});
EXPECT(check_compiled_program(p,
{migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref")}));
migraphx::parameter_map params; migraphx::parameter_map params;
params["x"] = migraphx::fill_argument(ds, 2); params["x"] = migraphx::fill_argument(ds, 2);
params["y"] = migraphx::fill_argument(ds, 3); params["y"] = migraphx::fill_argument(ds, 3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment