Commit 31b6db1a authored by umangyadav's avatar umangyadav
Browse files

if then else tests are working

parent b448194e
......@@ -21,9 +21,11 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/target_assignments.hpp"
#include <cstddef>
#include <limits>
#include <iterator>
#include <unordered_map>
#include <unordered_set>
#include <migraphx/env.hpp>
......@@ -65,21 +67,20 @@ static literal get_scalar(instruction_ref ins)
}
return r;
}
void partition(migraphx::program& p, const target_assignments& tass)
void partition(migraphx::module_ref mm,
migraphx::program& p,
const target_assignments& tass,
std::unordered_map<std::size_t, std::size_t>& tid_counter)
{
auto* mm = p.get_main_module();
// sort the graph in reverse post order DFS order
mm->sort();
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "sorted program: \n";
p.debug_print();
std::cout << "sorted module: \n";
mm->debug_print();
}
std::vector<instruction_ref> same_tid_ins_vec;
std::unordered_set<instruction_ref> same_tid_ins_set;
// walk the graph in reverse-DFS order
std::unordered_map<std::size_t, std::size_t> tid_counter;
size_t current_tid = std::numeric_limits<std::size_t>::max();
std::unordered_set<instruction_ref> skip_ins;
for(auto ins : iterator_for(*mm))
......@@ -94,12 +95,29 @@ void partition(migraphx::program& p, const target_assignments& tass)
ins->debug_print();
std::cout << "\n";
}
if(skip_ins.count(ins) == 0)
{
if(not ins->module_inputs().empty())
{
for(auto sub_mod : ins->module_inputs())
{
partition(sub_mod, p, tass, tid_counter);
}
mm->replace_instruction(
ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
}
}
if((starts_with(ins->name(), "@") and ins->name() != "@return") or skip_ins.count(ins) != 0)
{
continue;
}
else if(ins->name() != "@return" and current_tid == std::numeric_limits<std::size_t>::max())
{
if(tass.find(ins) == tass.end())
{
continue;
}
current_tid = tass.at(ins);
tid_counter[current_tid] = 0;
same_tid_ins_vec.push_back(ins);
......@@ -110,7 +128,7 @@ void partition(migraphx::program& p, const target_assignments& tass)
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else
else if(ins->name() == "@return" or tass.at(ins) != current_tid)
{
// gather all parameters
std::unordered_set<instruction_ref> params;
......@@ -152,7 +170,8 @@ void partition(migraphx::program& p, const target_assignments& tass)
std::cout << "\n";
}
auto* tmod = p.create_module("target_mod:" + std::to_string(tid_counter[current_tid]));
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid]));
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
......@@ -180,7 +199,8 @@ void partition(migraphx::program& p, const target_assignments& tass)
std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ?
auto tmod_tins = tmod->add_instruction((*tins)->get_operator(), new_inputs);
auto tmod_tins = tmod->add_instruction(
(*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
// add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins;
}
......@@ -232,11 +252,19 @@ void partition(migraphx::program& p, const target_assignments& tass)
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "program after creation of tmod and rot: \n";
p.debug_print();
std::cout << "module after creation of tmod and rot: \n";
mm->debug_print();
}
}
}
}
void partition(migraphx::program& p, const target_assignments& tass)
{
auto* mm = p.get_main_module();
// sort the graph in reverse post order DFS order
std::unordered_map<std::size_t, std::size_t> tid_counter;
partition(mm, p, tass, tid_counter);
dead_code_elimination{}.apply(p);
}
......
......@@ -256,24 +256,9 @@ TEST_CASE(multitarget_compile_if_then_else_partition)
auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
else_mod->add_return({a2});
// auto* run_on_cpu_mod = p.create_module("run_on_cpu");
// auto run_cpu_ins = run_on_cpu_mod->add_instruction(
// migraphx::make_op("run_on_target", {{"target_id", 1}}), {y}, {else_mod});
// auto run_cpu_ins_0 = run_on_cpu_mod->add_instruction(
// migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_cpu_ins);
// run_on_cpu_mod->add_return({run_cpu_ins_0});
// auto* run_on_gpu_mod = p.create_module("run_on_gpu");
// auto run_gpu_ins = run_on_gpu_mod->add_instruction(
// migraphx::make_op("run_on_target", {{"target_id", 0}}), {x}, {then_mod});
// auto run_gpu_ins_0 = run_on_gpu_mod->add_instruction(
// migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_gpu_ins);
// run_on_gpu_mod->add_return({run_gpu_ins_0});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
p.debug_print();
migraphx::target_assignments tass;
tass.insert(tass.begin(), std::make_pair(l1, 0));
tass.insert(tass.begin(), std::make_pair(a1, 0));
......@@ -284,8 +269,13 @@ TEST_CASE(multitarget_compile_if_then_else_partition)
// compile
migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true;
p.compile({migraphx::make_target("gpu"), migraphx::make_target("cpu")}, {gpu_opts});
EXPECT(check_compiled_program(p, {migraphx::make_target("gpu"), migraphx::make_target("cpu")}));
p.compile(
{migraphx::make_target("gpu"), migraphx::make_target("cpu"), migraphx::make_target("ref")},
{gpu_opts});
EXPECT(check_compiled_program(p,
{migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref")}));
migraphx::parameter_map params;
params["x"] = migraphx::fill_argument(ds, 2);
params["y"] = migraphx::fill_argument(ds, 3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment