Commit fa12da23 authored by Umang Yadav's avatar Umang Yadav
Browse files

Changes for the order fix

parent 1e80ceef
...@@ -174,18 +174,51 @@ struct auto_gen_root_modules ...@@ -174,18 +174,51 @@ struct auto_gen_root_modules
*/ */
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid) bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
if(inputs.size() == 1) size_t in_degree = inputs.size();
if(in_degree == 1)
{ {
return false; return false;
} }
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { size_t input_from_other_tid_module = 0;
return ( size_t num_default_tids = 0;
(this->skip_ins.find(input_ins) != skip_ins.end()) or size_t num_different_tids = 0;
(tass.find(input_ins) != tass.end() and size_t num_same_tid = 0;
tass.at(input_ins) != ins_tid.value_or(std::numeric_limits<std::size_t>::max()))); // std::unordered_map<size_t, size_t> in_tid_freq_map;
}); for(const auto& input_ins : inputs)
{
if(skip_ins.find(input_ins) != skip_ins.end())
{
input_from_other_tid_module++;
}
else if(tass.find(input_ins) == tass.end())
{
num_default_tids++;
}
else if(tass.at(input_ins) != ins_tid)
{
num_different_tids++;
}
else
{
num_same_tid++;
}
}
assert(input_from_other_tid_module + num_default_tids + num_different_tids + num_same_tid ==
in_degree);
if(input_from_other_tid_module > 1)
{
return true;
}
else if(input_from_other_tid_module + num_default_tids == in_degree)
{
return false;
}
else if(num_same_tid + num_default_tids == in_degree)
{
return false;
}
return true;
} }
/* /*
...@@ -200,21 +233,47 @@ struct auto_gen_root_modules ...@@ -200,21 +233,47 @@ struct auto_gen_root_modules
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph. itself then, it is classified as boundary for subgraph.
*/ */
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid) bool is_fork_node(migraphx::instruction_ref ins, std::size_t ins_tid)
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
if(outputs.size() == 1) if(outputs.size() == 1)
{ {
return false; return false;
} }
// if all the outputs are for the "default" or with same tid then it is not a fork but
// rather simply a boundary
std::unordered_map<std::size_t, std::size_t> output_tids;
for(const auto& output_ins : outputs)
{
if(tass.find(output_ins) != tass.end())
{
auto out_tid = tass.at(output_ins);
if(output_tids.find(out_tid) == output_tids.end())
{
output_tids[out_tid] = 1;
}
else
{
output_tids[out_tid]++;
}
}
}
if(output_tids.empty())
{
return false;
}
else if(output_tids.size() == 1 and output_tids.cbegin()->second == outputs.size())
{
return false;
}
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(output_ins->name() == "return") if(output_ins->name() == "return")
{ {
return false; return false;
} }
return (tass.find(output_ins) != tass.end() and return (tass.find(output_ins) != tass.end() and tass.at(output_ins) != ins_tid);
tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max()));
}); });
} }
...@@ -262,7 +321,7 @@ struct auto_gen_root_modules ...@@ -262,7 +321,7 @@ struct auto_gen_root_modules
current_tid = std::make_optional<std::size_t>(tass.at(ins)); current_tid = std::make_optional<std::size_t>(tass.at(ins));
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
fork_node = is_fork_node(ins, current_tid); fork_node = is_fork_node(ins, current_tid.value());
} }
} }
else else
...@@ -281,7 +340,8 @@ struct auto_gen_root_modules ...@@ -281,7 +340,8 @@ struct auto_gen_root_modules
{ {
MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur"); MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
} }
fork_node = is_fork_node(ins, current_tid); fork_node = is_fork_node(
ins, current_tid.value_or(std::numeric_limits<std::size_t>::max()));
} }
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
...@@ -315,7 +375,8 @@ struct auto_gen_root_modules ...@@ -315,7 +375,8 @@ struct auto_gen_root_modules
return; return;
} }
// gather all parameters // gather all parameters
std::unordered_set<instruction_ref> params; std::unordered_set<instruction_ref> params_set;
std::vector<instruction_ref> params_vec;
// gather all return values // gather all return values
std::vector<instruction_ref> return_ins; std::vector<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec)) for(auto tins : iterator_for(same_tid_ins_vec))
...@@ -325,11 +386,15 @@ struct auto_gen_root_modules ...@@ -325,11 +386,15 @@ struct auto_gen_root_modules
transform_if( transform_if(
inputs.cbegin(), inputs.cbegin(),
inputs.cend(), inputs.cend(),
std::inserter(params, params.end()), std::back_inserter(params_vec),
[&](auto in_param) { [&](auto in_param) {
return (params.count(in_param) == 0 and same_tid_ins_set.count(in_param) == 0); return (params_set.count(in_param) == 0 and
same_tid_ins_set.count(in_param) == 0);
}, },
[&](auto in_param) { return in_param; }); [&](auto in_param) {
params_set.insert(in_param);
return in_param;
});
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) { if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0; return same_tid_ins_set.count(out_ins) == 0;
})) }))
...@@ -340,7 +405,7 @@ struct auto_gen_root_modules ...@@ -340,7 +405,7 @@ struct auto_gen_root_modules
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{})) if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{ {
std::cout << "params ins: \n"; std::cout << "params ins: \n";
for(auto tmp : iterator_for(params)) for(auto tmp : iterator_for(params_vec))
{ {
(*tmp)->debug_print(); (*tmp)->debug_print();
} }
...@@ -357,7 +422,7 @@ struct auto_gen_root_modules ...@@ -357,7 +422,7 @@ struct auto_gen_root_modules
std::unordered_map<instruction_ref, instruction_ref> params_map; std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0; std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params; std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params)) for(auto pins : iterator_for(params_vec))
{ {
auto scalar = get_scalar(*pins); auto scalar = get_scalar(*pins);
if(scalar.empty()) if(scalar.empty())
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment