Commit 8db527c7 authored by Umang Yadav's avatar Umang Yadav
Browse files

add unit-tests and fixes

parent 10f76134
...@@ -152,32 +152,29 @@ struct auto_gen_root_modules ...@@ -152,32 +152,29 @@ struct auto_gen_root_modules
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if(tass.find(input_ins) != tass.end() and if((skip_ins.find(input_ins) != skip_ins.end()) or
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())) (tass.find(input_ins) != tass.end() and
{ tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())))
return true; {
} return true;
return false; }
})) return false;
return true; });
return false;
} }
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(tass.find(output_ins) != tass.end() and if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return") output_ins->name() != "@return")
{ {
return true; return true;
} }
return false; return false;
})) });
return true;
return false;
} }
void find_subgraphs(migraphx::module_ref mm, migraphx::program& p) void find_subgraphs(migraphx::module_ref mm, migraphx::program& p)
...@@ -196,6 +193,7 @@ struct auto_gen_root_modules ...@@ -196,6 +193,7 @@ struct auto_gen_root_modules
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{})) if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{ {
std::cout << "looking at instruction: \n"; std::cout << "looking at instruction: \n";
std::cout << "ins->name() == " << ins->name() << std::endl;
ins->debug_print(); ins->debug_print();
} }
...@@ -214,6 +212,18 @@ struct auto_gen_root_modules ...@@ -214,6 +212,18 @@ struct auto_gen_root_modules
update_tid_counter(current_tid.value()); update_tid_counter(current_tid.value());
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
if(is_fork_node(ins, current_tid))
{
generate_run_on_target_modules(mm, p, std::next(ins), current_tid.value());
if(not same_tid_ins_vec.empty())
{
// generate() method would populate these container for next(ins),
// remove them to maintain invariant
current_tid = nullopt;
same_tid_ins_set.erase(std::next(ins));
same_tid_ins_vec.pop_back();
}
}
} }
} }
else else
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -42,6 +42,74 @@ ...@@ -42,6 +42,74 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/generate_root_modules.hpp> #include <migraphx/generate_root_modules.hpp>
#include <migraphx/target_assignments.hpp> #include <migraphx/target_assignments.hpp>
#include "test.hpp" #include <test.hpp>
TEST_CASE(fork_case)
{
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
mm->add_return({mul_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 1));
}
migraphx::generate_root_modules(p1, tass);
p1.debug_print();
};
TEST_CASE(merge_case)
{
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), z_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, identity_ins);
mm->add_return({mul_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 1));
}
migraphx::generate_root_modules(p1, tass);
p1.debug_print();
};
TEST_CASE(fork_and_merge_case)
{
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto merge_ins = mm->add_instruction(migraphx::make_op("sub"), identity_ins, mul_ins);
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 1));
tass.insert(tass.begin(), std::make_pair(merge_ins, 0));
mm->add_return({merge_ins});
}
migraphx::generate_root_modules(p1, tass);
p1.debug_print();
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment