Commit bd12f16b authored by Umang Yadav's avatar Umang Yadav
Browse files

Fixes for merge and fork node

parent 34b9258a
......@@ -154,7 +154,8 @@ struct auto_gen_root_modules
{
const auto inputs = ins->inputs();
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if(is_different_subgraph(input_ins, tid))
if(tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max()))
{
return true;
}
......@@ -168,7 +169,9 @@ struct auto_gen_root_modules
{
const auto outputs = ins->outputs();
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(is_different_subgraph(output_ins, tid) and output_ins->name() != "@return")
if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return")
{
return true;
}
......@@ -206,11 +209,7 @@ struct auto_gen_root_modules
}
if(not current_tid.has_value())
{
if(tass.find(ins) == tass.end())
{
continue;
}
else
if(tass.find(ins) != tass.end())
{
current_tid = std::make_optional<std::size_t>(tass.at(ins));
update_tid_counter(current_tid.value());
......@@ -244,11 +243,11 @@ struct auto_gen_root_modules
}
else
{
MIGRAPHX_THROW("Partition: this case shouldn't occur");
MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
}
}
if(skip_ins.find(ins) == skip_ins.end() and not ins->module_inputs().empty())
if(not ins->module_inputs().empty())
{
std::vector<instruction_ref> same_tid_ins_vec_copy = {};
std::unordered_set<instruction_ref> same_tid_ins_set_copy = {};
......@@ -308,12 +307,11 @@ struct auto_gen_root_modules
{
(*tmp)->debug_print();
}
std::cout << "\n return ins: \n";
std::cout << "return ins: \n";
for(auto tmp : iterator_for(return_ins))
{
(*tmp)->debug_print();
}
std::cout << "\n";
}
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
......
......@@ -134,6 +134,7 @@ bool check_compiled_program(const migraphx::program& p,
const std::vector<migraphx::target>& targets)
{
auto mods = p.get_modules();
bool rot_ins = false;
bool check_compiled = true;
for(const auto* mod : mods)
{
......@@ -141,6 +142,7 @@ bool check_compiled_program(const migraphx::program& p,
{
if(ins.name() == "run_on_target")
{
rot_ins |= true;
auto* mod_input = ins.module_inputs().front();
std::size_t target_id =
ins.get_operator().to_value()["target_id"].to<std::size_t>();
......@@ -156,7 +158,7 @@ bool check_compiled_program(const migraphx::program& p,
}
}
}
return check_compiled;
return check_compiled and rot_ins;
}
TEST_CASE(multitarget_compile_cpu_gpu)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment