Commit bd12f16b authored by Umang Yadav
Browse files

Fixes for merge and fork node

parent 34b9258a
...@@ -154,7 +154,8 @@ struct auto_gen_root_modules ...@@ -154,7 +154,8 @@ struct auto_gen_root_modules
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if(is_different_subgraph(input_ins, tid)) if(tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max()))
{ {
return true; return true;
} }
...@@ -168,7 +169,9 @@ struct auto_gen_root_modules ...@@ -168,7 +169,9 @@ struct auto_gen_root_modules
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(is_different_subgraph(output_ins, tid) and output_ins->name() != "@return") if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return")
{ {
return true; return true;
} }
...@@ -206,11 +209,7 @@ struct auto_gen_root_modules ...@@ -206,11 +209,7 @@ struct auto_gen_root_modules
} }
if(not current_tid.has_value()) if(not current_tid.has_value())
{ {
if(tass.find(ins) == tass.end()) if(tass.find(ins) != tass.end())
{
continue;
}
else
{ {
current_tid = std::make_optional<std::size_t>(tass.at(ins)); current_tid = std::make_optional<std::size_t>(tass.at(ins));
update_tid_counter(current_tid.value()); update_tid_counter(current_tid.value());
...@@ -244,11 +243,11 @@ struct auto_gen_root_modules ...@@ -244,11 +243,11 @@ struct auto_gen_root_modules
} }
else else
{ {
MIGRAPHX_THROW("Partition: this case shouldn't occur"); MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
} }
} }
if(skip_ins.find(ins) == skip_ins.end() and not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
std::vector<instruction_ref> same_tid_ins_vec_copy = {}; std::vector<instruction_ref> same_tid_ins_vec_copy = {};
std::unordered_set<instruction_ref> same_tid_ins_set_copy = {}; std::unordered_set<instruction_ref> same_tid_ins_set_copy = {};
...@@ -308,12 +307,11 @@ struct auto_gen_root_modules ...@@ -308,12 +307,11 @@ struct auto_gen_root_modules
{ {
(*tmp)->debug_print(); (*tmp)->debug_print();
} }
std::cout << "\n return ins: \n"; std::cout << "return ins: \n";
for(auto tmp : iterator_for(return_ins)) for(auto tmp : iterator_for(return_ins))
{ {
(*tmp)->debug_print(); (*tmp)->debug_print();
} }
std::cout << "\n";
} }
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" + auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
......
...@@ -134,6 +134,7 @@ bool check_compiled_program(const migraphx::program& p, ...@@ -134,6 +134,7 @@ bool check_compiled_program(const migraphx::program& p,
const std::vector<migraphx::target>& targets) const std::vector<migraphx::target>& targets)
{ {
auto mods = p.get_modules(); auto mods = p.get_modules();
bool rot_ins = false;
bool check_compiled = true; bool check_compiled = true;
for(const auto* mod : mods) for(const auto* mod : mods)
{ {
...@@ -141,6 +142,7 @@ bool check_compiled_program(const migraphx::program& p, ...@@ -141,6 +142,7 @@ bool check_compiled_program(const migraphx::program& p,
{ {
if(ins.name() == "run_on_target") if(ins.name() == "run_on_target")
{ {
rot_ins |= true;
auto* mod_input = ins.module_inputs().front(); auto* mod_input = ins.module_inputs().front();
std::size_t target_id = std::size_t target_id =
ins.get_operator().to_value()["target_id"].to<std::size_t>(); ins.get_operator().to_value()["target_id"].to<std::size_t>();
...@@ -156,7 +158,7 @@ bool check_compiled_program(const migraphx::program& p, ...@@ -156,7 +158,7 @@ bool check_compiled_program(const migraphx::program& p,
} }
} }
} }
return check_compiled; return check_compiled and rot_ins;
} }
TEST_CASE(multitarget_compile_cpu_gpu) TEST_CASE(multitarget_compile_cpu_gpu)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment