Commit 34b9258a authored by Umang Yadav's avatar Umang Yadav
Browse files

Fixes

parent 40e6b38a
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/instruction_ref.hpp"
#include <cstddef> #include <cstddef>
#include <limits> #include <limits>
#include <iterator> #include <iterator>
...@@ -142,16 +141,16 @@ struct auto_gen_root_modules ...@@ -142,16 +141,16 @@ struct auto_gen_root_modules
} }
} }
bool is_different_subgraph(migraphx::instruction_ref ins, size_t tid) bool is_different_subgraph(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
if(tass.find(ins) == tass.end() or tass.at(ins) != tid) if(tass.find(ins) == tass.end())
{ {
return true; return tid.has_value();
} }
return false; return tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max());
} }
bool is_merge_node(migraphx::instruction_ref ins, size_t tid) bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
...@@ -165,7 +164,7 @@ struct auto_gen_root_modules ...@@ -165,7 +164,7 @@ struct auto_gen_root_modules
return false; return false;
} }
bool is_fork_node(migraphx::instruction_ref ins, size_t tid) bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
...@@ -189,7 +188,7 @@ struct auto_gen_root_modules ...@@ -189,7 +188,7 @@ struct auto_gen_root_modules
mm->debug_print(); mm->debug_print();
} }
size_t current_tid = std::numeric_limits<std::size_t>::max(); std::optional<std::size_t> current_tid = nullopt;
for(auto ins : iterator_for(*mm)) for(auto ins : iterator_for(*mm))
{ {
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
...@@ -205,38 +204,48 @@ struct auto_gen_root_modules ...@@ -205,38 +204,48 @@ struct auto_gen_root_modules
{ {
continue; continue;
} }
else if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or if(not current_tid.has_value())
is_merge_node(ins, current_tid))
{ {
generate_run_on_target_modules(mm, p, ins, current_tid); if(tass.find(ins) == tass.end())
}
else if(is_fork_node(ins, current_tid))
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
generate_run_on_target_modules(mm, p, std::next(ins), current_tid);
if(not same_tid_ins_vec.empty())
{ {
current_tid = std::numeric_limits<std::size_t>::max(); continue;
same_tid_ins_set.erase(ins); }
same_tid_ins_vec.pop_back(); else
{
current_tid = std::make_optional<std::size_t>(tass.at(ins));
update_tid_counter(current_tid.value());
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
} }
}
else if(current_tid == std::numeric_limits<std::size_t>::max())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid);
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) == current_tid)
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
} }
else else
{ {
MIGRAPHX_THROW("Partition: this case shouldn't occur"); if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or
is_merge_node(ins, current_tid))
{
generate_run_on_target_modules(mm, p, ins, current_tid.value());
}
else if(is_fork_node(ins, current_tid))
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
generate_run_on_target_modules(mm, p, std::next(ins), current_tid.value());
if(not same_tid_ins_vec.empty())
{
current_tid = nullopt;
same_tid_ins_set.erase(ins);
same_tid_ins_vec.pop_back();
}
}
else if(tass.at(ins) == current_tid.value())
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else
{
MIGRAPHX_THROW("Partition: this case shouldn't occur");
}
} }
if(skip_ins.find(ins) == skip_ins.end() and not ins->module_inputs().empty()) if(skip_ins.find(ins) == skip_ins.end() and not ins->module_inputs().empty())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment