"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4d46cbdb1a82ca923802e8a33ad89e14e1f02e78"
Commit 8b2ee166 authored by Umang Yadav's avatar Umang Yadav
Browse files

Fork and merge cases working

parent 8db527c7
...@@ -152,29 +152,41 @@ struct auto_gen_root_modules ...@@ -152,29 +152,41 @@ struct auto_gen_root_modules
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { if(inputs.size() == 1)
if((skip_ins.find(input_ins) != skip_ins.end()) or {
(tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())))
{
return true;
}
return false; return false;
}); }
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if((skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())))
{
return true;
}
return false;
}))
return true;
return false;
} }
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { if(outputs.size() == 1)
if(tass.find(output_ins) != tass.end() and {
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return")
{
return true;
}
return false; return false;
}); }
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return")
{
return true;
}
return false;
}))
return true;
return false;
} }
void find_subgraphs(migraphx::module_ref mm, migraphx::program& p) void find_subgraphs(migraphx::module_ref mm, migraphx::program& p)
...@@ -186,17 +198,28 @@ struct auto_gen_root_modules ...@@ -186,17 +198,28 @@ struct auto_gen_root_modules
std::cout << "sorted module: \n"; std::cout << "sorted module: \n";
mm->debug_print(); mm->debug_print();
} }
bool fork_node = false;
std::optional<std::size_t> current_tid = nullopt; std::optional<std::size_t> current_tid = nullopt;
for(auto ins : iterator_for(*mm)) for(auto ins : iterator_for(*mm))
{ {
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{})) if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{ {
std::cout << "looking at instruction: \n"; std::cout << "looking at instruction: \n";
std::cout << "ins->name() == " << ins->name() << std::endl;
ins->debug_print(); ins->debug_print();
} }
if(fork_node)
{
std::cout << "found fork node\n";
assert(current_tid.has_value());
generate_run_on_target_modules(mm, p, ins, current_tid.value());
if(not same_tid_ins_vec.empty())
{
current_tid = nullopt;
same_tid_ins_set.erase(ins);
same_tid_ins_vec.pop_back();
}
fork_node = false;
}
// skip all params, literal and builtins other than return, skip "run_on_target_mod" // skip all params, literal and builtins other than return, skip "run_on_target_mod"
// ins // ins
if((starts_with(ins->name(), "@") and ins->name() != "@return") or if((starts_with(ins->name(), "@") and ins->name() != "@return") or
...@@ -212,18 +235,7 @@ struct auto_gen_root_modules ...@@ -212,18 +235,7 @@ struct auto_gen_root_modules
update_tid_counter(current_tid.value()); update_tid_counter(current_tid.value());
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
if(is_fork_node(ins, current_tid)) fork_node = is_fork_node(ins, current_tid);
{
generate_run_on_target_modules(mm, p, std::next(ins), current_tid.value());
if(not same_tid_ins_vec.empty())
{
// generate() method would populate these container for next(ins),
// remove them to maintain invariant
current_tid = nullopt;
same_tid_ins_set.erase(std::next(ins));
same_tid_ins_vec.pop_back();
}
}
} }
} }
else else
...@@ -233,20 +245,6 @@ struct auto_gen_root_modules ...@@ -233,20 +245,6 @@ struct auto_gen_root_modules
{ {
generate_run_on_target_modules(mm, p, ins, current_tid.value()); generate_run_on_target_modules(mm, p, ins, current_tid.value());
} }
else if(is_fork_node(ins, current_tid))
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
generate_run_on_target_modules(mm, p, std::next(ins), current_tid.value());
if(not same_tid_ins_vec.empty())
{
// generate() method would populate these container for next(ins), remove
// them to maintain invariant
current_tid = nullopt;
same_tid_ins_set.erase(std::next(ins));
same_tid_ins_vec.pop_back();
}
}
else if(tass.at(ins) == current_tid.value()) else if(tass.at(ins) == current_tid.value())
{ {
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
...@@ -256,6 +254,7 @@ struct auto_gen_root_modules ...@@ -256,6 +254,7 @@ struct auto_gen_root_modules
{ {
MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur"); MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
} }
fork_node = is_fork_node(ins, current_tid);
} }
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment