"docs/source/experiment/vscode:/vscode.git/clone" did not exist on "fac7364a09d03f46a660aec7e6f3c911ed223336"
Commit 89e71273 authored by umang yadav's avatar umang yadav
Browse files

Partitioner working for nested if then else modules

parent 6c93676a
...@@ -67,6 +67,175 @@ static literal get_scalar(instruction_ref ins) ...@@ -67,6 +67,175 @@ static literal get_scalar(instruction_ref ins)
} }
return r; return r;
} }
void update_tid_counter(std::size_t tid, std::unordered_map<std::size_t, std::size_t>& tid_counter)
{
assert(tid != std::numeric_limits<std::size_t>::max());
if(tid_counter.find(tid) != tid_counter.end())
{
tid_counter[tid]++;
}
else
{
tid_counter[tid] = 0;
}
}
void generate_run_on_target_modules(migraphx::module_ref mm,
migraphx::program& p,
migraphx::instruction_ref ins,
std::size_t& current_tid,
const target_assignments& tass,
std::unordered_set<instruction_ref>& skip_ins,
std::unordered_map<std::size_t, std::size_t>& tid_counter,
std::vector<instruction_ref>& same_tid_ins_vec,
std::unordered_set<instruction_ref>& same_tid_ins_set)
{
assert(same_tid_ins_vec.size() == same_tid_ins_set.size());
if(same_tid_ins_vec.empty())
{
assert(current_tid == std::numeric_limits<std::size_t>::max());
return;
}
// gather all parameters
std::unordered_set<instruction_ref> params;
// gather all return values
std::unordered_set<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec))
{
auto inputs = (*tins)->inputs();
auto outputs = (*tins)->outputs();
transform_if(
inputs.cbegin(),
inputs.cend(),
std::inserter(params, params.end()),
[&](auto in_param) {
return (params.count(in_param) == 0 and same_tid_ins_set.count(in_param) == 0);
},
[&](auto in_param) { return in_param; });
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0;
}))
{
return_ins.insert(*tins);
}
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "params ins: \n";
for(auto tmp : iterator_for(params))
{
(*tmp)->debug_print();
}
std::cout << "\n";
std::cout << "return ins: \n";
for(auto tmp : iterator_for(return_ins))
{
(*tmp)->debug_print();
}
std::cout << "\n";
}
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid]));
update_tid_counter(current_tid, tid_counter);
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params))
{
auto scalar = get_scalar(*pins);
if(scalar.empty())
{
params_map[*pins] = tmod->add_parameter("param:" + std::to_string(param_counter++),
(*pins)->get_shape());
rot_ins_params.push_back(*pins);
}
else
{
params_map[*pins] = tmod->add_literal(scalar);
}
}
// TODO: what if params_map is empty ?
for(auto tins : iterator_for(same_tid_ins_vec))
{
auto inputs = (*tins)->inputs();
std::vector<instruction_ref> new_inputs;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ?
auto tmod_tins =
tmod->add_instruction((*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
// add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins;
}
std::vector<instruction_ref> rins;
std::unordered_map<instruction_ref, std::size_t> return_ins_idx_map;
for(auto ritr : iterator_for(return_ins))
{
rins.push_back(params_map.at(*ritr));
return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin());
}
tmod->add_return(rins);
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "tmod: \n";
tmod->debug_print();
}
// add run_on_target ins
auto rot_ins = mm->insert_instruction(
ins, make_op("run_on_target", {{"target_id", current_tid}}), rot_ins_params, {tmod});
skip_ins.insert(rot_ins);
// fetch return instructions from tuple
for(auto mm_rins : iterator_for(return_ins))
{
auto tuple_elem_ins = mm->insert_instruction(
ins, make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}), rot_ins);
skip_ins.insert(tuple_elem_ins);
// replace returns from tmod in main module
mm->replace_instruction(*mm_rins, tuple_elem_ins);
}
dead_code_elimination{}.apply(*mm);
// update current_tid
same_tid_ins_set.clear();
same_tid_ins_vec.clear();
if(tass.find(ins) != tass.end())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid, tid_counter);
same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins);
}
else
{
current_tid = std::numeric_limits<std::size_t>::max();
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "module after creation of tmod and rot: \n";
mm->debug_print();
}
}
/*
Given target assignments (tass) for the instructions, generate run_on_target modules subgraphs
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshape instructions should be part of tass. Copy, sync and alloc
instructions would be generated by compiler at later stage, so those shouldn't be considered.
(TODO): CustomOps may require special handling.
Identify subgraph boundaries, Ref is used for instructions that do not have assignments
1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target 2
3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same
4. Target X --> "@return"
5. Target X --> Ref --> "@return"
*/
void partition(migraphx::module_ref mm, void partition(migraphx::module_ref mm,
migraphx::program& p, migraphx::program& p,
const target_assignments& tass, const target_assignments& tass,
...@@ -107,154 +276,62 @@ void partition(migraphx::module_ref mm, ...@@ -107,154 +276,62 @@ void partition(migraphx::module_ref mm,
ins, ins->get_operator(), ins->inputs(), ins->module_inputs()); ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
} }
} }
if(ins->name() == "@return")
if((starts_with(ins->name(), "@") and ins->name() != "@return") or skip_ins.count(ins) != 0) {
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
// skip all params, literal and builitins other than return, skip "run_on_target_mod" ins
else if(starts_with(ins->name(), "@") or skip_ins.count(ins) != 0)
{ {
continue; continue;
} }
else if(ins->name() != "@return" and current_tid == std::numeric_limits<std::size_t>::max()) else if(tass.find(ins) == tass.end())
{ {
if(tass.find(ins) == tass.end()) generate_run_on_target_modules(mm,
{ p,
continue; ins,
} current_tid,
current_tid = tass.at(ins); tass,
tid_counter[current_tid] = 0; skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
else if(current_tid == std::numeric_limits<std::size_t>::max())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid, tid_counter);
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
} }
else if(ins->name() != "@return" and tass.at(ins) == current_tid) else if(tass.at(ins) == current_tid)
{ {
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
} }
else if(ins->name() == "@return" or tass.at(ins) != current_tid) else if(tass.at(ins) != current_tid)
{ {
// gather all parameters generate_run_on_target_modules(mm,
std::unordered_set<instruction_ref> params; p,
// gather all return values ins,
std::unordered_set<instruction_ref> return_ins; current_tid,
for(auto tins : iterator_for(same_tid_ins_vec)) tass,
{ skip_ins,
auto inputs = (*tins)->inputs(); tid_counter,
auto outputs = (*tins)->outputs(); same_tid_ins_vec,
transform_if( same_tid_ins_set);
inputs.cbegin(), }
inputs.cend(), else
std::inserter(params, params.end()), {
[&](auto in_param) { MIGRAPHX_THROW("Partition: this shouldn't occur");
return (params.count(in_param) == 0 and
same_tid_ins_set.count(in_param) == 0);
},
[&](auto in_param) { return in_param; });
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0;
}))
{
return_ins.insert(*tins);
}
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "params ins: \n";
for(auto tmp : iterator_for(params))
{
(*tmp)->debug_print();
}
std::cout << "\n";
std::cout << "return ins: \n";
for(auto tmp : iterator_for(return_ins))
{
(*tmp)->debug_print();
}
std::cout << "\n";
}
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid]));
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params))
{
auto scalar = get_scalar(*pins);
if(scalar.empty())
{
params_map[*pins] = tmod->add_parameter(
"param:" + std::to_string(param_counter++), (*pins)->get_shape());
rot_ins_params.push_back(*pins);
}
else
{
params_map[*pins] = tmod->add_literal(scalar);
}
}
// TODO: what if params_map is empty ?
for(auto tins : iterator_for(same_tid_ins_vec))
{
auto inputs = (*tins)->inputs();
std::vector<instruction_ref> new_inputs;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ?
auto tmod_tins = tmod->add_instruction(
(*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
// add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins;
}
std::vector<instruction_ref> rins;
std::unordered_map<instruction_ref, std::size_t> return_ins_idx_map;
for(auto ritr : iterator_for(return_ins))
{
rins.push_back(params_map.at(*ritr));
return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin());
}
tmod->add_return(rins);
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "tmod: \n";
tmod->debug_print();
}
// add run_on_target ins
auto rot_ins =
mm->insert_instruction(ins,
make_op("run_on_target", {{"target_id", current_tid}}),
rot_ins_params,
{tmod});
skip_ins.insert(rot_ins);
// fetch return instructions from tuple
for(auto mm_rins : iterator_for(return_ins))
{
auto tuple_elem_ins = mm->insert_instruction(
ins,
make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}),
rot_ins);
skip_ins.insert(tuple_elem_ins);
// replace returns from tmod in main module
mm->replace_instruction(*mm_rins, tuple_elem_ins);
}
dead_code_elimination{}.apply(*mm);
// update current_tid
if(ins->name() != "@return")
{
current_tid = tass.at(ins);
if(tid_counter.count(current_tid) == 0)
{
tid_counter[current_tid] = 0;
}
tid_counter[current_tid]++;
same_tid_ins_set.clear();
same_tid_ins_vec.clear();
same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins);
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "module after creation of tmod and rot: \n";
mm->debug_print();
}
} }
} }
} }
......
...@@ -223,7 +223,7 @@ TEST_CASE(single_target_multi_compile) ...@@ -223,7 +223,7 @@ TEST_CASE(single_target_multi_compile)
// eval // eval
migraphx::parameter_map params; migraphx::parameter_map params;
std::vector<float> boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0, std::vector<float> boxes_vec = {0.5, 0.5, 1.0, 1.0, 0.5, 0.6, 1.0, 1.0, 0.5, 0.4, 1.0, 1.0,
0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}; 0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
params["boxes"] = migraphx::argument(boxes_s, boxes_vec.data()); params["boxes"] = migraphx::argument(boxes_s, boxes_vec.data());
auto output = p.eval(params).back(); auto output = p.eval(params).back();
std::vector<int64_t> gold_vec = {0, 0, 3, 0, 0, 0, 0, 0, 5}; std::vector<int64_t> gold_vec = {0, 0, 3, 0, 0, 0, 0, 0, 5};
...@@ -510,6 +510,148 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -510,6 +510,148 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
} }
} }
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE(multitarget_compile_nested_if_then_else_partition)
{
std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}};
migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
migraphx::target_assignments tass;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
auto cond_0 = mm->add_parameter("cond_0", cond_s);
auto cond_1 = mm->add_parameter("cond_1", cond_s);
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds);
auto create_test_module = [&](migraphx::program& prog,
const std::vector<migraphx::instruction_ref>& inputs,
std::size_t tid) {
std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name);
std::vector<float> data(ds.elements(), -1);
auto l1 = test_mod->add_literal(migraphx::literal(ds, data));
auto ins1 = test_mod->add_instruction(migraphx::make_op("add"), inputs[0], l1);
auto ins2 = test_mod->add_instruction(migraphx::make_op("mul"), ins1, inputs[1]);
auto ins3 = test_mod->add_instruction(migraphx::make_op("sub"), ins2, inputs[2]);
test_mod->add_return({ins3});
tass.insert(tass.begin(), std::make_pair(ins1, tid));
tass.insert(tass.begin(), std::make_pair(ins2, tid));
tass.insert(tass.begin(), std::make_pair(ins3, tid));
return test_mod;
};
auto* then_mod = p.create_module("then_mod");
auto then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto then_mod_param_0 = then_mod->add_parameter("then_mod_param_0", ds);
auto then_mod_param_1 = then_mod->add_parameter("then_mod_param_1", ds);
auto then_mod_param_2 = then_mod->add_parameter("then_mod_param_2", ds);
auto then_mod_ref_ins =
then_mod->add_instruction(migraphx::make_op("add"), then_mod_param_0, then_mod_param_1);
tass.insert(tass.begin(), std::make_pair(then_mod_ref_ins, 3));
auto then_mod_if = then_mod->add_instruction(
migraphx::make_op("if"),
{then_mod_cond,
then_mod_param_0,
then_mod_param_1,
then_mod_param_2,
then_mod_ref_ins,
then_mod_param_1,
then_mod_param_2},
{create_test_module(p, {then_mod_param_0, then_mod_param_1, then_mod_param_2}, 1),
create_test_module(p, {then_mod_ref_ins, then_mod_param_1, then_mod_param_2}, 0)});
auto then_mod_if_0 =
then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if);
then_mod->add_return({then_mod_if_0});
// create nested else_mod with multiple targets.
// else_mod has one instruction that runs a module on "fpga" and another instruction that
// creates nested modules using "If" that runs on "cpu" and "gpu"
auto* else_mod = p.create_module("else_mod");
auto else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto else_mod_param_0 = else_mod->add_parameter("else_mod_param_0", ds);
auto else_mod_param_1 = else_mod->add_parameter("else_mod_param_1", ds);
auto else_mod_param_2 = else_mod->add_parameter("else_mod_param_2", ds);
auto else_mod_fpga_ins =
else_mod->add_instruction(migraphx::make_op("add"), else_mod_param_0, else_mod_param_2);
tass.insert(tass.begin(), std::make_pair(else_mod_fpga_ins, 2));
auto else_mod_if = else_mod->add_instruction(
migraphx::make_op("if"),
{else_mod_cond,
else_mod_fpga_ins,
else_mod_param_0,
else_mod_param_1,
else_mod_param_2,
else_mod_param_1,
else_mod_param_0},
{create_test_module(p, {else_mod_fpga_ins, else_mod_param_0, else_mod_param_1}, 0),
create_test_module(p, {else_mod_param_2, else_mod_param_1, else_mod_param_0}, 1)});
auto else_mod_if_0 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if);
else_mod->add_return({else_mod_if_0});
// Create nested and multi-target main module using "If"
auto main_if_ins = mm->add_instruction(
migraphx::make_op("if"), {cond_0, cond_1, x, y, z, cond_1, x, y, z}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), main_if_ins);
mm->add_return({r});
// compile
migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true;
std::cout << "before parition\n";
p.debug_print();
migraphx::partition(p, tass);
std::cout << "after partition\n";
p.debug_print();
p.compile({migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref"),
migraphx::make_target("ref")},
{gpu_opts});
std::cout << "after compilation\n";
p.debug_print();
EXPECT(check_compiled_program(p,
{migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref"),
migraphx::make_target("ref")}));
// do evaluation using different conditions
migraphx::parameter_map params;
float x_i = 2.0;
float y_i = 3.0;
float z_i = 4.0;
params["x"] = migraphx::fill_argument(ds, x_i);
params["y"] = migraphx::fill_argument(ds, y_i);
params["z"] = migraphx::fill_argument(ds, z_i);
// cover all paths with different combination of conditions
std::vector<std::pair<bool, bool>> test_conds = {
{true, true}, {true, false}, {false, true}, {false, false}};
for(auto [cond_val_0, cond_val_1] : test_conds)
{
params["cond_0"] = migraphx::argument(cond_s, &cond_val_0);
params["cond_1"] = migraphx::argument(cond_s, &cond_val_1);
auto result = p.eval(params).back();
// main has one instruction that is : if_then_else
// then mod is doing : {tmp = x+y; (cond) ? (((x-1)*y)-z) : (((tmp-1)*y)-z);}
// else mod is doing : {tmp = x+z; (cond) ? (((tmp-1)*x)-y) : (((z-1)*y)-x);}
float gold_i = -1.0;
if(cond_val_0)
{
float tmp_i = x_i + y_i;
gold_i = (cond_val_1) ? (((x_i - 1) * y_i) - z_i) : (((tmp_i - 1) * y_i) - z_i);
}
else
{
float tmp_i = x_i + z_i;
gold_i = (cond_val_1) ? (((tmp_i - 1) * x_i) - y_i) : (((z_i - 1) * y_i) - x_i);
}
auto gold = migraphx::fill_argument(ds, gold_i);
EXPECT(gold == result);
}
}
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it // TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE(multitarget_select_module) TEST_CASE(multitarget_select_module)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment