"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6416f066e7642b418d5ffb3637167c66aaad8627"
Commit 89e71273 authored by umang yadav's avatar umang yadav
Browse files

Partitioner working for nested if then else modules

parent 6c93676a
...@@ -67,69 +67,36 @@ static literal get_scalar(instruction_ref ins) ...@@ -67,69 +67,36 @@ static literal get_scalar(instruction_ref ins)
} }
return r; return r;
} }
void partition(migraphx::module_ref mm,
migraphx::program& p, void update_tid_counter(std::size_t tid, std::unordered_map<std::size_t, std::size_t>& tid_counter)
const target_assignments& tass,
std::unordered_map<std::size_t, std::size_t>& tid_counter)
{ {
mm->sort(); assert(tid != std::numeric_limits<std::size_t>::max());
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(tid_counter.find(tid) != tid_counter.end())
{
std::cout << "sorted module: \n";
mm->debug_print();
}
std::vector<instruction_ref> same_tid_ins_vec;
std::unordered_set<instruction_ref> same_tid_ins_set;
// walk the graph in reverse-DFS order
size_t current_tid = std::numeric_limits<std::size_t>::max();
std::unordered_set<instruction_ref> skip_ins;
for(auto ins : iterator_for(*mm))
{ {
// gather instructions belonging to the same target_id tid_counter[tid]++;
// for now, make sure that all the inputs to the insturctions are also from the same
// target_id, if not create another module
// skip all the builtins
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "currently processing: \n";
ins->debug_print();
std::cout << "\n";
} }
if(skip_ins.count(ins) == 0) else
{
if(not ins->module_inputs().empty())
{
for(auto sub_mod : ins->module_inputs())
{ {
partition(sub_mod, p, tass, tid_counter); tid_counter[tid] = 0;
}
mm->replace_instruction(
ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
}
} }
}
if((starts_with(ins->name(), "@") and ins->name() != "@return") or skip_ins.count(ins) != 0) void generate_run_on_target_modules(migraphx::module_ref mm,
{ migraphx::program& p,
continue; migraphx::instruction_ref ins,
} std::size_t& current_tid,
else if(ins->name() != "@return" and current_tid == std::numeric_limits<std::size_t>::max()) const target_assignments& tass,
{ std::unordered_set<instruction_ref>& skip_ins,
if(tass.find(ins) == tass.end()) std::unordered_map<std::size_t, std::size_t>& tid_counter,
{ std::vector<instruction_ref>& same_tid_ins_vec,
continue; std::unordered_set<instruction_ref>& same_tid_ins_set)
} {
current_tid = tass.at(ins); assert(same_tid_ins_vec.size() == same_tid_ins_set.size());
tid_counter[current_tid] = 0; if(same_tid_ins_vec.empty())
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(ins->name() != "@return" and tass.at(ins) == current_tid)
{ {
same_tid_ins_vec.push_back(ins); assert(current_tid == std::numeric_limits<std::size_t>::max());
same_tid_ins_set.insert(ins); return;
} }
else if(ins->name() == "@return" or tass.at(ins) != current_tid)
{
// gather all parameters // gather all parameters
std::unordered_set<instruction_ref> params; std::unordered_set<instruction_ref> params;
// gather all return values // gather all return values
...@@ -143,8 +110,7 @@ void partition(migraphx::module_ref mm, ...@@ -143,8 +110,7 @@ void partition(migraphx::module_ref mm,
inputs.cend(), inputs.cend(),
std::inserter(params, params.end()), std::inserter(params, params.end()),
[&](auto in_param) { [&](auto in_param) {
return (params.count(in_param) == 0 and return (params.count(in_param) == 0 and same_tid_ins_set.count(in_param) == 0);
same_tid_ins_set.count(in_param) == 0);
}, },
[&](auto in_param) { return in_param; }); [&](auto in_param) { return in_param; });
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) { if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
...@@ -172,6 +138,7 @@ void partition(migraphx::module_ref mm, ...@@ -172,6 +138,7 @@ void partition(migraphx::module_ref mm,
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" + auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid])); std::to_string(tid_counter[current_tid]));
update_tid_counter(current_tid, tid_counter);
std::unordered_map<instruction_ref, instruction_ref> params_map; std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0; std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params; std::vector<instruction_ref> rot_ins_params;
...@@ -180,8 +147,8 @@ void partition(migraphx::module_ref mm, ...@@ -180,8 +147,8 @@ void partition(migraphx::module_ref mm,
auto scalar = get_scalar(*pins); auto scalar = get_scalar(*pins);
if(scalar.empty()) if(scalar.empty())
{ {
params_map[*pins] = tmod->add_parameter( params_map[*pins] = tmod->add_parameter("param:" + std::to_string(param_counter++),
"param:" + std::to_string(param_counter++), (*pins)->get_shape()); (*pins)->get_shape());
rot_ins_params.push_back(*pins); rot_ins_params.push_back(*pins);
} }
else else
...@@ -199,8 +166,8 @@ void partition(migraphx::module_ref mm, ...@@ -199,8 +166,8 @@ void partition(migraphx::module_ref mm,
std::back_inserter(new_inputs), std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); }); [&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ? // [TODO]: what if it is has module args ?
auto tmod_tins = tmod->add_instruction( auto tmod_tins =
(*tins)->get_operator(), new_inputs, (*tins)->module_inputs()); tmod->add_instruction((*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
// add parameter to params map so that it can be looked up by other insturctions // add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins; params_map[*tins] = tmod_tins;
} }
...@@ -218,43 +185,153 @@ void partition(migraphx::module_ref mm, ...@@ -218,43 +185,153 @@ void partition(migraphx::module_ref mm,
tmod->debug_print(); tmod->debug_print();
} }
// add run_on_target ins // add run_on_target ins
auto rot_ins = auto rot_ins = mm->insert_instruction(
mm->insert_instruction(ins, ins, make_op("run_on_target", {{"target_id", current_tid}}), rot_ins_params, {tmod});
make_op("run_on_target", {{"target_id", current_tid}}),
rot_ins_params,
{tmod});
skip_ins.insert(rot_ins); skip_ins.insert(rot_ins);
// fetch return instructions from tuple // fetch return instructions from tuple
for(auto mm_rins : iterator_for(return_ins)) for(auto mm_rins : iterator_for(return_ins))
{ {
auto tuple_elem_ins = mm->insert_instruction( auto tuple_elem_ins = mm->insert_instruction(
ins, ins, make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}), rot_ins);
make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}),
rot_ins);
skip_ins.insert(tuple_elem_ins); skip_ins.insert(tuple_elem_ins);
// replace returns from tmod in main module // replace returns from tmod in main module
mm->replace_instruction(*mm_rins, tuple_elem_ins); mm->replace_instruction(*mm_rins, tuple_elem_ins);
} }
dead_code_elimination{}.apply(*mm); dead_code_elimination{}.apply(*mm);
// update current_tid // update current_tid
if(ins->name() != "@return")
{
current_tid = tass.at(ins);
if(tid_counter.count(current_tid) == 0)
{
tid_counter[current_tid] = 0;
}
tid_counter[current_tid]++;
same_tid_ins_set.clear(); same_tid_ins_set.clear();
same_tid_ins_vec.clear(); same_tid_ins_vec.clear();
if(tass.find(ins) != tass.end())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid, tid_counter);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
} }
else
{
current_tid = std::numeric_limits<std::size_t>::max();
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{ {
std::cout << "module after creation of tmod and rot: \n"; std::cout << "module after creation of tmod and rot: \n";
mm->debug_print(); mm->debug_print();
} }
}
/*
Given target assignments (tass) for the instructions, generate run_on_target modules subgraphs
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshape instructions should be part of tass. Copy, sync and alloc
instructions would be generated by compiler at later stage, so those shouldn't be considered.
(TODO): CustomOps may require special handling.
Identify subgraph boundaries, Ref is used for instructions that do not have assignments
1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target 2
3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same
4. Target X --> "@return"
5. Target X --> Ref --> "@return"
*/
void partition(migraphx::module_ref mm,
migraphx::program& p,
const target_assignments& tass,
std::unordered_map<std::size_t, std::size_t>& tid_counter)
{
mm->sort();
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "sorted module: \n";
mm->debug_print();
}
std::vector<instruction_ref> same_tid_ins_vec;
std::unordered_set<instruction_ref> same_tid_ins_set;
// walk the graph in reverse-DFS order
size_t current_tid = std::numeric_limits<std::size_t>::max();
std::unordered_set<instruction_ref> skip_ins;
for(auto ins : iterator_for(*mm))
{
// gather instructions belonging to the same target_id
// for now, make sure that all the inputs to the insturctions are also from the same
// target_id, if not create another module
// skip all the builtins
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "currently processing: \n";
ins->debug_print();
std::cout << "\n";
}
if(skip_ins.count(ins) == 0)
{
if(not ins->module_inputs().empty())
{
for(auto sub_mod : ins->module_inputs())
{
partition(sub_mod, p, tass, tid_counter);
}
mm->replace_instruction(
ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
}
}
if(ins->name() == "@return")
{
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
// skip all params, literal and builitins other than return, skip "run_on_target_mod" ins
else if(starts_with(ins->name(), "@") or skip_ins.count(ins) != 0)
{
continue;
}
else if(tass.find(ins) == tass.end())
{
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
else if(current_tid == std::numeric_limits<std::size_t>::max())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid, tid_counter);
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) == current_tid)
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) != current_tid)
{
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
else
{
MIGRAPHX_THROW("Partition: this shouldn't occur");
} }
} }
} }
......
...@@ -510,6 +510,148 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -510,6 +510,148 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
} }
} }
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE(multitarget_compile_nested_if_then_else_partition)
{
std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}};
migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
migraphx::target_assignments tass;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
auto cond_0 = mm->add_parameter("cond_0", cond_s);
auto cond_1 = mm->add_parameter("cond_1", cond_s);
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds);
auto create_test_module = [&](migraphx::program& prog,
const std::vector<migraphx::instruction_ref>& inputs,
std::size_t tid) {
std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name);
std::vector<float> data(ds.elements(), -1);
auto l1 = test_mod->add_literal(migraphx::literal(ds, data));
auto ins1 = test_mod->add_instruction(migraphx::make_op("add"), inputs[0], l1);
auto ins2 = test_mod->add_instruction(migraphx::make_op("mul"), ins1, inputs[1]);
auto ins3 = test_mod->add_instruction(migraphx::make_op("sub"), ins2, inputs[2]);
test_mod->add_return({ins3});
tass.insert(tass.begin(), std::make_pair(ins1, tid));
tass.insert(tass.begin(), std::make_pair(ins2, tid));
tass.insert(tass.begin(), std::make_pair(ins3, tid));
return test_mod;
};
auto* then_mod = p.create_module("then_mod");
auto then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto then_mod_param_0 = then_mod->add_parameter("then_mod_param_0", ds);
auto then_mod_param_1 = then_mod->add_parameter("then_mod_param_1", ds);
auto then_mod_param_2 = then_mod->add_parameter("then_mod_param_2", ds);
auto then_mod_ref_ins =
then_mod->add_instruction(migraphx::make_op("add"), then_mod_param_0, then_mod_param_1);
tass.insert(tass.begin(), std::make_pair(then_mod_ref_ins, 3));
auto then_mod_if = then_mod->add_instruction(
migraphx::make_op("if"),
{then_mod_cond,
then_mod_param_0,
then_mod_param_1,
then_mod_param_2,
then_mod_ref_ins,
then_mod_param_1,
then_mod_param_2},
{create_test_module(p, {then_mod_param_0, then_mod_param_1, then_mod_param_2}, 1),
create_test_module(p, {then_mod_ref_ins, then_mod_param_1, then_mod_param_2}, 0)});
auto then_mod_if_0 =
then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if);
then_mod->add_return({then_mod_if_0});
// create nested else_mod with multiple targets.
// else_mod has one instruction that runs a module on "fpga" and another instruction that
// creates nested modules using "If" that runs on "cpu" and "gpu"
auto* else_mod = p.create_module("else_mod");
auto else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto else_mod_param_0 = else_mod->add_parameter("else_mod_param_0", ds);
auto else_mod_param_1 = else_mod->add_parameter("else_mod_param_1", ds);
auto else_mod_param_2 = else_mod->add_parameter("else_mod_param_2", ds);
auto else_mod_fpga_ins =
else_mod->add_instruction(migraphx::make_op("add"), else_mod_param_0, else_mod_param_2);
tass.insert(tass.begin(), std::make_pair(else_mod_fpga_ins, 2));
auto else_mod_if = else_mod->add_instruction(
migraphx::make_op("if"),
{else_mod_cond,
else_mod_fpga_ins,
else_mod_param_0,
else_mod_param_1,
else_mod_param_2,
else_mod_param_1,
else_mod_param_0},
{create_test_module(p, {else_mod_fpga_ins, else_mod_param_0, else_mod_param_1}, 0),
create_test_module(p, {else_mod_param_2, else_mod_param_1, else_mod_param_0}, 1)});
auto else_mod_if_0 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if);
else_mod->add_return({else_mod_if_0});
// Create nested and multi-target main module using "If"
auto main_if_ins = mm->add_instruction(
migraphx::make_op("if"), {cond_0, cond_1, x, y, z, cond_1, x, y, z}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), main_if_ins);
mm->add_return({r});
// compile
migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true;
std::cout << "before parition\n";
p.debug_print();
migraphx::partition(p, tass);
std::cout << "after partition\n";
p.debug_print();
p.compile({migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref"),
migraphx::make_target("ref")},
{gpu_opts});
std::cout << "after compilation\n";
p.debug_print();
EXPECT(check_compiled_program(p,
{migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref"),
migraphx::make_target("ref")}));
// do evaluation using different conditions
migraphx::parameter_map params;
float x_i = 2.0;
float y_i = 3.0;
float z_i = 4.0;
params["x"] = migraphx::fill_argument(ds, x_i);
params["y"] = migraphx::fill_argument(ds, y_i);
params["z"] = migraphx::fill_argument(ds, z_i);
// cover all paths with different combination of conditions
std::vector<std::pair<bool, bool>> test_conds = {
{true, true}, {true, false}, {false, true}, {false, false}};
for(auto [cond_val_0, cond_val_1] : test_conds)
{
params["cond_0"] = migraphx::argument(cond_s, &cond_val_0);
params["cond_1"] = migraphx::argument(cond_s, &cond_val_1);
auto result = p.eval(params).back();
// main has one instruction that is : if_then_else
// then mod is doing : {tmp = x+y; (cond) ? (((x-1)*y)-z) : (((tmp-1)*y)-z);}
// else mod is doing : {tmp = x+z; (cond) ? (((tmp-1)*x)-y) : (((z-1)*y)-x);}
float gold_i = -1.0;
if(cond_val_0)
{
float tmp_i = x_i + y_i;
gold_i = (cond_val_1) ? (((x_i - 1) * y_i) - z_i) : (((tmp_i - 1) * y_i) - z_i);
}
else
{
float tmp_i = x_i + z_i;
gold_i = (cond_val_1) ? (((tmp_i - 1) * x_i) - y_i) : (((z_i - 1) * y_i) - x_i);
}
auto gold = migraphx::fill_argument(ds, gold_i);
EXPECT(gold == result);
}
}
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it // TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE(multitarget_select_module) TEST_CASE(multitarget_select_module)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment