Commit 0b2c59f1 authored by umang yadav's avatar umang yadav
Browse files

Use struct to keep state

parent 3633f1a8
...@@ -71,7 +71,7 @@ add_library(migraphx ...@@ -71,7 +71,7 @@ add_library(migraphx
operation.cpp operation.cpp
optimize_module.cpp optimize_module.cpp
pad_calc.cpp pad_calc.cpp
partitioner.cpp generate_root_modules.cpp
pass_manager.cpp pass_manager.cpp
permutation.cpp permutation.cpp
preallocate_param.cpp preallocate_param.cpp
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/target_assignments.hpp"
#include <cstddef> #include <cstddef>
#include <limits> #include <limits>
#include <iterator> #include <iterator>
...@@ -31,7 +30,7 @@ ...@@ -31,7 +30,7 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/partitioner.hpp> #include <migraphx/generate_root_modules.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -44,6 +43,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_PARTITIONER) ...@@ -44,6 +43,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_PARTITIONER)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// copied from fuse_pointwise.cpp
static literal get_scalar(instruction_ref ins) static literal get_scalar(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
...@@ -68,8 +68,48 @@ static literal get_scalar(instruction_ref ins) ...@@ -68,8 +68,48 @@ static literal get_scalar(instruction_ref ins)
return r; return r;
} }
void update_tid_counter(std::size_t tid, std::unordered_map<std::size_t, std::size_t>& tid_counter) /*
Given target assignments (tass) for the instructions, generate run_on_target modules subgraphs
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshaper instructions should be part of tass. Copy, sync and
alloc instructions would be generated by compiler at later stage, so those shouldn't be considered.
(TODO): CustomOps may require special handling.
Step 1:
Identify subgraph boundaries
Ref is used for instructions that do not have assignments.
Boundaries can happen in following cases.
1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target Y
3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same
4. Target X --> "@return"
5. Target X --> Ref --> "@return"
Each of those identified regions could have futher nested sub modules which needs to be handled
separately.
Step 2:
Collect parameters and return instructions for the subgraphs identified in Step 1.
Step 3:
Create modules using information collected in step 2 and insert run_on_target instructions.
*/
struct AutoGenRootModules
{ {
AutoGenRootModules(migraphx::program& p, const target_assignments& target_assignments)
: tass(target_assignments)
{
auto* mm = p.get_main_module();
find_subgraphs(mm, p);
dead_code_elimination{}.apply(p);
}
void update_tid_counter(std::size_t tid)
{
assert(tid != std::numeric_limits<std::size_t>::max()); assert(tid != std::numeric_limits<std::size_t>::max());
if(tid_counter.find(tid) != tid_counter.end()) if(tid_counter.find(tid) != tid_counter.end())
{ {
...@@ -79,18 +119,87 @@ void update_tid_counter(std::size_t tid, std::unordered_map<std::size_t, std::si ...@@ -79,18 +119,87 @@ void update_tid_counter(std::size_t tid, std::unordered_map<std::size_t, std::si
{ {
tid_counter[tid] = 0; tid_counter[tid] = 0;
} }
} }
void find_subgraphs(migraphx::module_ref mm, migraphx::program& p)
{
// sort the graph in reverse post order DFS order
mm->sort();
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "sorted module: \n";
mm->debug_print();
}
size_t current_tid = std::numeric_limits<std::size_t>::max();
for(auto ins : iterator_for(*mm))
{
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "looking at instruction: \n";
ins->debug_print();
std::cout << "\n";
}
if(ins->name() == "@return")
{
generate_run_on_target_modules(mm, p, ins, current_tid);
}
// skip all params, literal and builtins other than return, skip "run_on_target_mod"
// ins
else if(starts_with(ins->name(), "@") or skip_ins.count(ins) != 0)
{
continue;
}
else if(tass.find(ins) == tass.end())
{
generate_run_on_target_modules(mm, p, ins, current_tid);
}
else if(current_tid == std::numeric_limits<std::size_t>::max())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid);
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) == current_tid)
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) != current_tid)
{
generate_run_on_target_modules(mm, p, ins, current_tid);
}
else
{
MIGRAPHX_THROW("Partition: this case shouldn't occur");
}
void generate_run_on_target_modules(migraphx::module_ref mm, if(skip_ins.find(ins) == skip_ins.end() and not ins->module_inputs().empty())
{
std::vector<instruction_ref> same_tid_ins_vec_copy = {};
std::unordered_set<instruction_ref> same_tid_ins_set_copy = {};
std::swap(same_tid_ins_set_copy, same_tid_ins_set);
std::swap(same_tid_ins_vec_copy, same_tid_ins_vec);
for(auto sub_mod : ins->module_inputs())
{
find_subgraphs(sub_mod, p);
}
std::swap(same_tid_ins_set_copy, same_tid_ins_set);
std::swap(same_tid_ins_vec_copy, same_tid_ins_vec);
mm->replace_instruction(
ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
}
}
assert(same_tid_ins_set.empty() and same_tid_ins_vec.empty());
}
void generate_run_on_target_modules(migraphx::module_ref mm,
migraphx::program& p, migraphx::program& p,
migraphx::instruction_ref ins, migraphx::instruction_ref ins,
std::size_t& current_tid, std::size_t& current_tid)
const target_assignments& tass, {
std::unordered_set<instruction_ref>& skip_ins,
std::unordered_map<std::size_t, std::size_t>& tid_counter,
std::vector<instruction_ref>& same_tid_ins_vec,
std::unordered_set<instruction_ref>& same_tid_ins_set)
{
assert(same_tid_ins_vec.size() == same_tid_ins_set.size()); assert(same_tid_ins_vec.size() == same_tid_ins_set.size());
if(same_tid_ins_vec.empty()) if(same_tid_ins_vec.empty())
{ {
...@@ -127,8 +236,7 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -127,8 +236,7 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
{ {
(*tmp)->debug_print(); (*tmp)->debug_print();
} }
std::cout << "\n"; std::cout << "\n return ins: \n";
std::cout << "return ins: \n";
for(auto tmp : iterator_for(return_ins)) for(auto tmp : iterator_for(return_ins))
{ {
(*tmp)->debug_print(); (*tmp)->debug_print();
...@@ -136,9 +244,10 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -136,9 +244,10 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
std::cout << "\n"; std::cout << "\n";
} }
std::cout << "1\n";
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" + auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid])); std::to_string(tid_counter[current_tid]));
update_tid_counter(current_tid, tid_counter); update_tid_counter(current_tid);
std::unordered_map<instruction_ref, instruction_ref> params_map; std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0; std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params; std::vector<instruction_ref> rot_ins_params;
...@@ -156,7 +265,9 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -156,7 +265,9 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
params_map[*pins] = tmod->add_literal(scalar); params_map[*pins] = tmod->add_literal(scalar);
} }
} }
std::cout << "2\n";
// TODO: what if params_map is empty ? // TODO: what if params_map is empty ?
assert(not params_map.empty());
for(auto tins : iterator_for(same_tid_ins_vec)) for(auto tins : iterator_for(same_tid_ins_vec))
{ {
auto inputs = (*tins)->inputs(); auto inputs = (*tins)->inputs();
...@@ -165,23 +276,25 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -165,23 +276,25 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
inputs.end(), inputs.end(),
std::back_inserter(new_inputs), std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); }); [&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ? auto tmod_tins = tmod->add_instruction(
auto tmod_tins = (*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
tmod->add_instruction((*tins)->get_operator(), new_inputs, (*tins)->module_inputs());
// add parameter to params map so that it can be looked up by other insturctions // add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins; params_map[*tins] = tmod_tins;
} }
std::cout << "3\n";
std::vector<instruction_ref> rins; std::vector<instruction_ref> rins;
std::unordered_map<instruction_ref, std::size_t> return_ins_idx_map; std::unordered_map<instruction_ref, std::size_t> return_ins_idx_map;
std::cout << "4\n";
for(auto ritr : iterator_for(return_ins)) for(auto ritr : iterator_for(return_ins))
{ {
rins.push_back(params_map.at(*ritr)); rins.push_back(params_map.at(*ritr));
return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin()); return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin());
} }
tmod->add_return(rins); tmod->add_return(rins);
std::cout << "5\n";
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{ {
std::cout << "tmod: \n"; std::cout << "Created target module: " << tmod->name() << "\n";
tmod->debug_print(); tmod->debug_print();
} }
// add run_on_target ins // add run_on_target ins
...@@ -192,7 +305,9 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -192,7 +305,9 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
for(auto mm_rins : iterator_for(return_ins)) for(auto mm_rins : iterator_for(return_ins))
{ {
auto tuple_elem_ins = mm->insert_instruction( auto tuple_elem_ins = mm->insert_instruction(
ins, make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}), rot_ins); ins,
make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}),
rot_ins);
skip_ins.insert(tuple_elem_ins); skip_ins.insert(tuple_elem_ins);
// replace returns from tmod in main module // replace returns from tmod in main module
mm->replace_instruction(*mm_rins, tuple_elem_ins); mm->replace_instruction(*mm_rins, tuple_elem_ins);
...@@ -204,7 +319,7 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -204,7 +319,7 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
if(tass.find(ins) != tass.end()) if(tass.find(ins) != tass.end())
{ {
current_tid = tass.at(ins); current_tid = tass.at(ins);
update_tid_counter(current_tid, tid_counter); update_tid_counter(current_tid);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
} }
...@@ -214,146 +329,22 @@ void generate_run_on_target_modules(migraphx::module_ref mm, ...@@ -214,146 +329,22 @@ void generate_run_on_target_modules(migraphx::module_ref mm,
} }
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{})) if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{ {
std::cout << "module after creation of tmod and rot: \n"; std::cout << "Main module after creation of target module: " << tmod->name() << "\n";
mm->debug_print(); mm->debug_print();
} }
}
/*
Given target assignments (tass) for the instructions, generate run_on_target modules subgraphs
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshaper instructions should be part of tass. Copy, sync and
alloc instructions would be generated by compiler at later stage, so those shouldn't be considered.
(TODO): CustomOps may require special handling.
Step 1:
Identify subgraph boundaries
Ref is used for instructions that do not have assignments.
Boundaries can happen in following cases.
1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target Y
3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same
4. Target X --> "@return"
5. Target X --> Ref --> "@return"
Each of those identified regions could have futher nested sub modules which needs to be handled
separately.
Step 2:
Collect parameters and return instructions for the subgraphs identified in Step 1.
Step 3:
Create modules using information collected in step 2 and insert run_on_target instructions.
*/
void partition(migraphx::module_ref mm,
migraphx::program& p,
const target_assignments& tass,
std::unordered_map<std::size_t, std::size_t>& tid_counter)
{
mm->sort();
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "sorted module: \n";
mm->debug_print();
} }
private:
const target_assignments tass;
std::unordered_map<std::size_t, std::size_t> tid_counter;
std::unordered_set<instruction_ref> skip_ins;
std::vector<instruction_ref> same_tid_ins_vec; std::vector<instruction_ref> same_tid_ins_vec;
std::unordered_set<instruction_ref> same_tid_ins_set; std::unordered_set<instruction_ref> same_tid_ins_set;
// walk the graph in reverse-DFS order };
size_t current_tid = std::numeric_limits<std::size_t>::max();
std::unordered_set<instruction_ref> skip_ins;
for(auto ins : iterator_for(*mm))
{
// gather instructions belonging to the same target_id
// for now, make sure that all the inputs to the insturctions are also from the same
// target_id, if not create another module
// skip all the builtins
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "currently processing: \n";
ins->debug_print();
std::cout << "\n";
}
if(skip_ins.count(ins) == 0)
{
if(not ins->module_inputs().empty())
{
for(auto sub_mod : ins->module_inputs())
{
partition(sub_mod, p, tass, tid_counter);
}
mm->replace_instruction(
ins, ins->get_operator(), ins->inputs(), ins->module_inputs());
}
}
if(ins->name() == "@return")
{
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
// skip all params, literal and builitins other than return, skip "run_on_target_mod" ins
else if(starts_with(ins->name(), "@") or skip_ins.count(ins) != 0)
{
continue;
}
else if(tass.find(ins) == tass.end())
{
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
else if(current_tid == std::numeric_limits<std::size_t>::max())
{
current_tid = tass.at(ins);
update_tid_counter(current_tid, tid_counter);
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) == current_tid)
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(tass.at(ins) != current_tid)
{
generate_run_on_target_modules(mm,
p,
ins,
current_tid,
tass,
skip_ins,
tid_counter,
same_tid_ins_vec,
same_tid_ins_set);
}
else
{
MIGRAPHX_THROW("Partition: this shouldn't occur");
}
}
}
void partition(migraphx::program& p, const target_assignments& tass) void generate_root_modules(migraphx::program& p, const target_assignments& tass)
{ {
auto* mm = p.get_main_module(); AutoGenRootModules(p, tass);
// sort the graph in reverse post order DFS order
std::unordered_map<std::size_t, std::size_t> tid_counter;
partition(mm, p, tass, tid_counter);
dead_code_elimination{}.apply(p);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -30,9 +30,9 @@ ...@@ -30,9 +30,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/* /*
given target_assignments, paritions the migraphx program into separate root modules. given target_assignments, generates root modules for each individual targets inside main module.
*/ */
void partition(migraphx::program& p, const migraphx::target_assignments& tass); void generate_root_modules(migraphx::program& p, const migraphx::target_assignments& tass);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include "migraphx/partitioner.hpp" #include <migraphx/generate_root_modules.hpp>
#include "migraphx/target_assignments.hpp" #include "migraphx/target_assignments.hpp"
#include "test.hpp" #include "test.hpp"
...@@ -174,7 +174,7 @@ TEST_CASE(multitarget_compile_cpu_gpu) ...@@ -174,7 +174,7 @@ TEST_CASE(multitarget_compile_cpu_gpu)
migraphx::target_assignments tass; migraphx::target_assignments tass;
tass.insert(tass.begin(), std::make_pair(cpu_ins, 1)); tass.insert(tass.begin(), std::make_pair(cpu_ins, 1));
tass.insert(tass.begin(), std::make_pair(gpu_ins, 0)); tass.insert(tass.begin(), std::make_pair(gpu_ins, 0));
migraphx::partition(p, tass); migraphx::generate_root_modules(p, tass);
migraphx::compile_options gpu_opts; migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true; gpu_opts.offload_copy = true;
p.compile({migraphx::make_target("gpu"), migraphx::make_target("cpu")}, {gpu_opts}); p.compile({migraphx::make_target("gpu"), migraphx::make_target("cpu")}, {gpu_opts});
...@@ -210,10 +210,10 @@ TEST_CASE(single_target_multi_compile) ...@@ -210,10 +210,10 @@ TEST_CASE(single_target_multi_compile)
iou_threshold, iou_threshold,
score_threshold); score_threshold);
mm->add_return({r}); mm->add_return({r});
// do partition // do target assignments
migraphx::target_assignments tass; migraphx::target_assignments tass;
tass.insert(tass.begin(), std::make_pair(r, 0)); tass.insert(tass.begin(), std::make_pair(r, 0));
migraphx::partition(p, tass); migraphx::generate_root_modules(p, tass);
// compile using multi-target compilation path // compile using multi-target compilation path
migraphx::compile_options gpu_opts; migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true; gpu_opts.offload_copy = true;
...@@ -232,7 +232,7 @@ TEST_CASE(single_target_multi_compile) ...@@ -232,7 +232,7 @@ TEST_CASE(single_target_multi_compile)
EXPECT(output == gold); EXPECT(output == gold);
} }
TEST_CASE(multitarget_compile_if_then_else_partition) TEST_CASE(multitarget_compile_if_then_else)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -264,8 +264,7 @@ TEST_CASE(multitarget_compile_if_then_else_partition) ...@@ -264,8 +264,7 @@ TEST_CASE(multitarget_compile_if_then_else_partition)
tass.insert(tass.begin(), std::make_pair(a1, 0)); tass.insert(tass.begin(), std::make_pair(a1, 0));
tass.insert(tass.begin(), std::make_pair(l2, 1)); tass.insert(tass.begin(), std::make_pair(l2, 1));
tass.insert(tass.begin(), std::make_pair(a2, 1)); tass.insert(tass.begin(), std::make_pair(a2, 1));
migraphx::partition(p, tass); migraphx::generate_root_modules(p, tass);
p.debug_print();
// compile // compile
migraphx::compile_options gpu_opts; migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true; gpu_opts.offload_copy = true;
...@@ -288,230 +287,8 @@ TEST_CASE(multitarget_compile_if_then_else_partition) ...@@ -288,230 +287,8 @@ TEST_CASE(multitarget_compile_if_then_else_partition)
} }
} }
TEST_CASE(multitarget_compile_if_then_else)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
auto cond = mm->add_parameter("cond", cond_s);
migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto* then_mod = p.create_module("if_gpu_mod");
std::vector<float> data1(ds.elements(), 1);
auto l1 = then_mod->add_literal(migraphx::literal(ds, data1));
auto gpu_x = then_mod->add_parameter("gpu_x", ds);
auto a1 = then_mod->add_instruction(migraphx::make_op("add"), gpu_x, l1);
then_mod->add_return({a1});
auto* else_mod = p.create_module("else_cpu_mod");
std::vector<float> data2(ds.elements(), 2);
auto l2 = else_mod->add_literal(migraphx::literal(ds, data2));
auto cpu_y = else_mod->add_parameter("cpu_y", ds);
auto a2 = else_mod->add_instruction(migraphx::make_op("mul"), cpu_y, l2);
else_mod->add_return({a2});
auto* run_on_cpu_mod = p.create_module("run_on_cpu");
auto run_cpu_ins = run_on_cpu_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y}, {else_mod});
auto run_cpu_ins_0 = run_on_cpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_cpu_ins);
run_on_cpu_mod->add_return({run_cpu_ins_0});
auto* run_on_gpu_mod = p.create_module("run_on_gpu");
auto run_gpu_ins = run_on_gpu_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x}, {then_mod});
auto run_gpu_ins_0 = run_on_gpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_gpu_ins);
run_on_gpu_mod->add_return({run_gpu_ins_0});
auto ret =
mm->add_instruction(migraphx::make_op("if"), {cond}, {run_on_gpu_mod, run_on_cpu_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
// compile
migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true;
p.compile({migraphx::make_target("gpu"), migraphx::make_target("cpu")}, {gpu_opts});
EXPECT(check_compiled_program(p, {migraphx::make_target("gpu"), migraphx::make_target("cpu")}));
migraphx::parameter_map params;
params["x"] = migraphx::fill_argument(ds, 2);
params["y"] = migraphx::fill_argument(ds, 3);
for(bool cond_val : {true, false})
{
params["cond"] = migraphx::argument(cond_s, &cond_val);
auto result = p.eval(params).back();
auto gold = migraphx::fill_argument(ds, (cond_val ? 3 : 6));
EXPECT(gold == result);
}
}
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it // TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE(multitarget_compile_nested_if_then_else) TEST_CASE(multitarget_compile_nested_if_then_else)
{
std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}};
migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape cond_s{migraphx::shape::bool_type};
auto cond_0 = mm->add_parameter("cond_0", cond_s);
auto cond_1 = mm->add_parameter("cond_1", cond_s);
auto x = mm->add_parameter("x", ds);
auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds);
auto create_test_module = [&](migraphx::program& prog,
const std::vector<migraphx::instruction_ref>& inputs,
std::size_t tid) {
std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name);
std::vector<float> data(ds.elements(), -1);
auto l1 = test_mod->add_literal(migraphx::literal(ds, data));
auto test_mod_param_0 = test_mod->add_parameter(mod_name + "_param_0", ds);
auto test_mod_param_1 = test_mod->add_parameter(mod_name + "_param_1", ds);
auto test_mod_param_2 = test_mod->add_parameter(mod_name + "_param_2", ds);
auto ins1 = test_mod->add_instruction(migraphx::make_op("add"), test_mod_param_0, l1);
auto ins2 = test_mod->add_instruction(migraphx::make_op("mul"), ins1, test_mod_param_1);
auto ins3 = test_mod->add_instruction(migraphx::make_op("sub"), ins2, test_mod_param_2);
test_mod->add_return({ins3});
auto* run_on_target_mod = prog.create_module("run_on_" + mod_name);
auto run_ins = run_on_target_mod->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", tid}}), inputs, {test_mod});
auto run_ins_0 = run_on_target_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), run_ins);
run_on_target_mod->add_return({run_ins_0});
return run_on_target_mod;
};
// create nested module with multiple targets.
// then_mod has one instruction that runs a module on "ref" and another instruction that
// creates nested modules using "If" that runs on "cpu" and "gpu"
auto* ref_mod = p.create_module("ref_mod");
auto ref_x = ref_mod->add_parameter("ref_x", ds);
auto ref_y = ref_mod->add_parameter("ref_y", ds);
auto ref_add = ref_mod->add_instruction(migraphx::make_op("add"), ref_x, ref_y);
ref_mod->add_return({ref_add});
auto* then_mod = p.create_module("then_mod");
auto then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto then_mod_param_0 = then_mod->add_parameter("then_mod_param_0", ds);
auto then_mod_param_1 = then_mod->add_parameter("then_mod_param_1", ds);
auto then_mod_param_2 = then_mod->add_parameter("then_mod_param_2", ds);
auto then_mod_ref_ins =
then_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 3}}),
{then_mod_param_0, then_mod_param_1},
{ref_mod});
auto then_mod_ref_ins_0 = then_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_ref_ins);
auto then_mod_if = then_mod->add_instruction(
migraphx::make_op("if"),
{then_mod_cond,
then_mod_param_0,
then_mod_param_1,
then_mod_param_2,
then_mod_ref_ins_0,
then_mod_param_1,
then_mod_param_2},
{create_test_module(p, {then_mod_param_0, then_mod_param_1, then_mod_param_2}, 1),
create_test_module(p, {then_mod_ref_ins_0, then_mod_param_1, then_mod_param_2}, 0)});
auto then_mod_if_0 =
then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if);
then_mod->add_return({then_mod_if_0});
// create nested else_mod with multiple targets.
// else_mod has one instruction that runs a module on "fpga" and another instruction that
// creates nested modules using "If" that runs on "cpu" and "gpu"
auto* fpga_mod = p.create_module("fpga_mod");
auto fpga_x = fpga_mod->add_parameter("fpga_x", ds);
auto fpga_y = fpga_mod->add_parameter("fpga_y", ds);
auto fpga_add = fpga_mod->add_instruction(migraphx::make_op("add"), fpga_x, fpga_y);
fpga_mod->add_return({fpga_add});
auto* else_mod = p.create_module("else_mod");
auto else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto else_mod_param_0 = else_mod->add_parameter("else_mod_param_0", ds);
auto else_mod_param_1 = else_mod->add_parameter("else_mod_param_1", ds);
auto else_mod_param_2 = else_mod->add_parameter("else_mod_param_2", ds);
auto else_mod_fpga_ins =
else_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 2}}),
{else_mod_param_0, else_mod_param_2},
{fpga_mod});
auto else_mod_fpga_ins_0 = else_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_fpga_ins);
auto else_mod_if = else_mod->add_instruction(
migraphx::make_op("if"),
{else_mod_cond,
else_mod_fpga_ins_0,
else_mod_param_0,
else_mod_param_1,
else_mod_param_2,
else_mod_param_1,
else_mod_param_0},
{create_test_module(p, {else_mod_fpga_ins_0, else_mod_param_0, else_mod_param_1}, 0),
create_test_module(p, {else_mod_param_2, else_mod_param_1, else_mod_param_0}, 1)});
auto else_mod_if_0 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if);
else_mod->add_return({else_mod_if_0});
// Create nested and multi-target main module using "If"
auto main_if_ins = mm->add_instruction(
migraphx::make_op("if"), {cond_0, cond_1, x, y, z, cond_1, x, y, z}, {then_mod, else_mod});
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), main_if_ins);
mm->add_return({r});
// compile
migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true;
p.compile({migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref"),
migraphx::make_target("ref")},
{gpu_opts});
EXPECT(check_compiled_program(p,
{migraphx::make_target("gpu"),
migraphx::make_target("cpu"),
migraphx::make_target("ref"),
migraphx::make_target("ref")}));
// do evaluation using different conditions
migraphx::parameter_map params;
float x_i = 2.0;
float y_i = 3.0;
float z_i = 4.0;
params["x"] = migraphx::fill_argument(ds, x_i);
params["y"] = migraphx::fill_argument(ds, y_i);
params["z"] = migraphx::fill_argument(ds, z_i);
// cover all paths with different combination of conditions
std::vector<std::pair<bool, bool>> test_conds = {
{true, true}, {true, false}, {false, true}, {false, false}};
for(auto [cond_val_0, cond_val_1] : test_conds)
{
params["cond_0"] = migraphx::argument(cond_s, &cond_val_0);
params["cond_1"] = migraphx::argument(cond_s, &cond_val_1);
auto result = p.eval(params).back();
// main has one instruction that is : if_then_else
// then mod is doing : {tmp = x+y; (cond) ? (((x-1)*y)-z) : (((tmp-1)*y)-z);}
// else mod is doing : {tmp = x+z; (cond) ? (((tmp-1)*x)-y) : (((z-1)*y)-x);}
float gold_i = -1.0;
if(cond_val_0)
{
float tmp_i = x_i + y_i;
gold_i = (cond_val_1) ? (((x_i - 1) * y_i) - z_i) : (((tmp_i - 1) * y_i) - z_i);
}
else
{
float tmp_i = x_i + z_i;
gold_i = (cond_val_1) ? (((tmp_i - 1) * x_i) - y_i) : (((z_i - 1) * y_i) - x_i);
}
auto gold = migraphx::fill_argument(ds, gold_i);
EXPECT(gold == result);
}
}
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE(multitarget_compile_nested_if_then_else_partition)
{ {
std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}}; std::unordered_map<std::size_t, std::size_t> counter_map = {{0, 0}, {1, 0}};
migraphx::shape ds{migraphx::shape::float_type, {2, 3}}; migraphx::shape ds{migraphx::shape::float_type, {2, 3}};
...@@ -525,7 +302,6 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition) ...@@ -525,7 +302,6 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition)
auto y = mm->add_parameter("y", ds); auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds); auto z = mm->add_parameter("z", ds);
auto create_test_module = [&](migraphx::program& prog, auto create_test_module = [&](migraphx::program& prog,
const std::vector<migraphx::instruction_ref>& inputs,
std::size_t tid) { std::size_t tid) {
std::string mod_name = std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++); "target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
...@@ -562,8 +338,8 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition) ...@@ -562,8 +338,8 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition)
then_mod_ref_ins, then_mod_ref_ins,
then_mod_param_1, then_mod_param_1,
then_mod_param_2}, then_mod_param_2},
{create_test_module(p, {then_mod_param_0, then_mod_param_1, then_mod_param_2}, 1), {create_test_module(p, 1),
create_test_module(p, {then_mod_ref_ins, then_mod_param_1, then_mod_param_2}, 0)}); create_test_module(p, 0)});
auto then_mod_if_0 = auto then_mod_if_0 =
then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if); then_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), then_mod_if);
then_mod->add_return({then_mod_if_0}); then_mod->add_return({then_mod_if_0});
...@@ -588,8 +364,8 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition) ...@@ -588,8 +364,8 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition)
else_mod_param_2, else_mod_param_2,
else_mod_param_1, else_mod_param_1,
else_mod_param_0}, else_mod_param_0},
{create_test_module(p, {else_mod_fpga_ins, else_mod_param_0, else_mod_param_1}, 0), {create_test_module(p, 0),
create_test_module(p, {else_mod_param_2, else_mod_param_1, else_mod_param_0}, 1)}); create_test_module(p, 1)});
auto else_mod_if_0 = auto else_mod_if_0 =
else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if); else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), else_mod_if);
else_mod->add_return({else_mod_if_0}); else_mod->add_return({else_mod_if_0});
...@@ -603,18 +379,12 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition) ...@@ -603,18 +379,12 @@ TEST_CASE(multitarget_compile_nested_if_then_else_partition)
// compile // compile
migraphx::compile_options gpu_opts; migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true; gpu_opts.offload_copy = true;
std::cout << "before parition\n"; migraphx::generate_root_modules(p, tass);
p.debug_print();
migraphx::partition(p, tass);
std::cout << "after partition\n";
p.debug_print();
p.compile({migraphx::make_target("gpu"), p.compile({migraphx::make_target("gpu"),
migraphx::make_target("cpu"), migraphx::make_target("cpu"),
migraphx::make_target("ref"), migraphx::make_target("ref"),
migraphx::make_target("ref")}, migraphx::make_target("ref")},
{gpu_opts}); {gpu_opts});
std::cout << "after compilation\n";
p.debug_print();
EXPECT(check_compiled_program(p, EXPECT(check_compiled_program(p,
{migraphx::make_target("gpu"), {migraphx::make_target("gpu"),
migraphx::make_target("cpu"), migraphx::make_target("cpu"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment