Commit fa12da23 authored by Umang Yadav
Browse files

Changes for the order fix

parent 1e80ceef
......@@ -174,18 +174,51 @@ struct auto_gen_root_modules
*/
// Decides whether `ins` counts as a merge node by classifying each of its inputs.
//
// ins      instruction whose inputs are inspected.
// ins_tid  tid associated with `ins`, or empty when it has none.
// Returns true when `ins` is classified as a merge node, false otherwise.
//
// NOTE(review): this listing comes from a commit diff view whose +/- markers are not
// shown, so superseded and replacement lines appear side by side (for example the two
// single-input checks and the std::any_of return that precedes the counting logic).
// Confirm against the real file which lines are live.
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
{
const auto inputs = ins->inputs();
// Single-input check written directly on inputs.size().
if(inputs.size() == 1)
// Same check, with the input count cached as in_degree for the bookkeeping below.
size_t in_degree = inputs.size();
if(in_degree == 1)
{
// A node with exactly one input is never a merge node.
return false;
}
// any_of formulation: merge node if any input is in skip_ins, or has a tass entry that
// differs from ins_tid (the maximum size_t value stands in for a missing ins_tid).
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
return (
(this->skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(input_ins) != ins_tid.value_or(std::numeric_limits<std::size_t>::max())));
});
// Counting formulation: every input lands in exactly one of the four buckets below.
size_t input_from_other_tid_module = 0;
size_t num_default_tids = 0;
size_t num_different_tids = 0;
size_t num_same_tid = 0;
// std::unordered_map<size_t, size_t> in_tid_freq_map;
for(const auto& input_ins : inputs)
{
// Input is present in skip_ins.
if(skip_ins.find(input_ins) != skip_ins.end())
{
input_from_other_tid_module++;
}
// Input has no entry in tass, i.e. no tid recorded for it.
else if(tass.find(input_ins) == tass.end())
{
num_default_tids++;
}
// Input has a tid that differs from ins_tid. An empty ins_tid compares unequal to
// every value, so in that case every input with a tass entry is counted here.
else if(tass.at(input_ins) != ins_tid)
{
num_different_tids++;
}
// Input has the same tid as ins_tid.
else
{
num_same_tid++;
}
}
// The branches above are mutually exclusive, so the buckets must add up to the input count.
assert(input_from_other_tid_module + num_default_tids + num_different_tids + num_same_tid ==
in_degree);
// More than one input found in skip_ins: merge node.
if(input_from_other_tid_module > 1)
{
return true;
}
// Every input is either in skip_ins or has no tid recorded: not a merge node.
else if(input_from_other_tid_module + num_default_tids == in_degree)
{
return false;
}
// Every input either shares ins_tid or has no tid recorded: not a merge node.
else if(num_same_tid + num_default_tids == in_degree)
{
return false;
}
// Any other mix of inputs is treated as a merge.
return true;
}
/*
......@@ -200,21 +233,47 @@ struct auto_gen_root_modules
For the partitioner, if any of the fork node's outputs does not have the same tid as the
fork node itself, then the fork node is classified as a boundary for the subgraph.
*/
// Decides whether `ins` counts as a fork node by looking at the tids of its outputs.
//
// ins      instruction whose outputs are inspected.
// ins_tid  tid associated with `ins`.
// Returns true when `ins` is classified as a fork node, false otherwise.
//
// NOTE(review): this listing comes from a commit diff view whose +/- markers are not
// shown, so two signatures (std::optional and plain std::size_t for ins_tid) and two
// versions of the final comparison appear side by side. Confirm against the real file
// which lines are live.
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
bool is_fork_node(migraphx::instruction_ref ins, std::size_t ins_tid)
{
const auto outputs = ins->outputs();
// A node with exactly one output is never a fork node.
if(outputs.size() == 1)
{
return false;
}
// If all the outputs are "default" (no tid recorded) or all carry the same tid, this is
// not a fork but simply a boundary.
// output_tids maps each tid seen among the outputs to the number of outputs carrying it;
// outputs without a tass entry are not counted.
std::unordered_map<std::size_t, std::size_t> output_tids;
for(const auto& output_ins : outputs)
{
if(tass.find(output_ins) != tass.end())
{
auto out_tid = tass.at(output_ins);
if(output_tids.find(out_tid) == output_tids.end())
{
output_tids[out_tid] = 1;
}
else
{
output_tids[out_tid]++;
}
}
}
// No output has a tid recorded in tass.
if(output_tids.empty())
{
return false;
}
// Every output has a tass entry and they all carry one and the same tid.
// NOTE(review): that shared tid is not compared with ins_tid here - confirm this is intended.
else if(output_tids.size() == 1 and output_tids.cbegin()->second == outputs.size())
{
return false;
}
// Otherwise: fork node if any output other than a "return" instruction has a tass entry
// whose tid differs from ins_tid.
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(output_ins->name() == "return")
{
return false;
}
// Comparison for the std::optional signature: the maximum size_t value stands in for a
// missing ins_tid.
return (tass.find(output_ins) != tass.end() and
tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max()));
// Comparison for the plain std::size_t signature.
return (tass.find(output_ins) != tass.end() and tass.at(output_ins) != ins_tid);
});
}
......@@ -262,7 +321,7 @@ struct auto_gen_root_modules
current_tid = std::make_optional<std::size_t>(tass.at(ins));
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
fork_node = is_fork_node(ins, current_tid);
fork_node = is_fork_node(ins, current_tid.value());
}
}
else
......@@ -281,7 +340,8 @@ struct auto_gen_root_modules
{
MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
}
fork_node = is_fork_node(ins, current_tid);
fork_node = is_fork_node(
ins, current_tid.value_or(std::numeric_limits<std::size_t>::max()));
}
if(not ins->module_inputs().empty())
......@@ -315,7 +375,8 @@ struct auto_gen_root_modules
return;
}
// gather all parameters
std::unordered_set<instruction_ref> params;
std::unordered_set<instruction_ref> params_set;
std::vector<instruction_ref> params_vec;
// gather all return values
std::vector<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec))
......@@ -325,11 +386,15 @@ struct auto_gen_root_modules
transform_if(
inputs.cbegin(),
inputs.cend(),
std::inserter(params, params.end()),
std::back_inserter(params_vec),
[&](auto in_param) {
return (params.count(in_param) == 0 and same_tid_ins_set.count(in_param) == 0);
return (params_set.count(in_param) == 0 and
same_tid_ins_set.count(in_param) == 0);
},
[&](auto in_param) { return in_param; });
[&](auto in_param) {
params_set.insert(in_param);
return in_param;
});
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0;
}))
......@@ -340,7 +405,7 @@ struct auto_gen_root_modules
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{
std::cout << "params ins: \n";
for(auto tmp : iterator_for(params))
for(auto tmp : iterator_for(params_vec))
{
(*tmp)->debug_print();
}
......@@ -357,7 +422,7 @@ struct auto_gen_root_modules
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params))
for(auto pins : iterator_for(params_vec))
{
auto scalar = get_scalar(*pins);
if(scalar.empty())
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment