"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4e55c4010656452525462cd3111a920a7c795538"
Commit 20edd2b6 authored by umangyadav's avatar umangyadav
Browse files

partitioner working for simple test

parent 55156faa
......@@ -222,7 +222,13 @@ struct MIGRAPHX_EXPORT module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules(bool shallow = false) const;
// sorts the module in reverse-post order DFS order, this is not considering any implicit deps
// [TODO] right now, this is not considering any implicit deps right now
module& sort();
// if the instruction has the module arguments then all the parameters/instructions used by that
// module from the main/parent module must be calculated before the instruction can be executed.
// therefore, apart from the input instructions, those other instructions are implicit
// dependencies
ins_dep_map calc_implicit_deps() const;
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
......
......@@ -32,7 +32,7 @@ inline namespace MIGRAPHX_INLINE_NS {
/*
given target_assignments, paritions the migraphx program into separate root modules.
*/
void partition(migraphx::program& p, migraphx::target_assignments tass);
void partition(migraphx::program& p, const migraphx::target_assignments& tass);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -21,10 +21,223 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/env.hpp"
#include <cstddef>
#include <limits>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/partitioner.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
#include <unordered_set>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_PARTITIONER)
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void partition(migraphx::program& p, target_assignments tass) {}
static literal get_scalar(instruction_ref ins)
{
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(s.elements() != 1 && not(s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
// needed for bool as visit_at invokes as() which promotes bool to int8
// Without this we'll break type checks for logical ops that are fused.
if(e.get_shape().type() == shape::bool_type)
{
r = literal{e.at<bool>()};
}
else
{
e.visit_at([&](auto x) { r = literal{x}; });
}
return r;
}
void partition(migraphx::program& p, const target_assignments& tass)
{
auto* mm = p.get_main_module();
// sort the graph in reverse post order DFS order
mm->sort();
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "sorted program: \n";
p.debug_print();
}
std::vector<instruction_ref> same_tid_ins_vec;
std::unordered_set<instruction_ref> same_tid_ins_set;
// walk the graph in reverse-DFS order
std::unordered_map<std::size_t, std::size_t> tid_counter;
size_t current_tid = std::numeric_limits<std::size_t>::max();
std::unordered_set<instruction_ref> skip_ins;
for(auto ins : iterator_for(*mm))
{
// gather instructions belonging to the same target_id
// for now, make sure that all the inputs to the insturctions are also from the same
// target_id, if not create another module
// skip all the builtins
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "currently processing: \n";
ins->debug_print();
std::cout << "\n";
}
if((starts_with(ins->name(), "@") and ins->name() != "@return") or skip_ins.count(ins) != 0)
{
continue;
}
else if(ins->name() != "@return" and current_tid == std::numeric_limits<std::size_t>::max())
{
current_tid = tass.at(ins);
tid_counter[current_tid] = 0;
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else if(ins->name() != "@return" and tass.at(ins) == current_tid)
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
}
else
{
// gather all parameters
std::unordered_set<instruction_ref> params;
// gather all return values
std::unordered_set<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec))
{
auto inputs = (*tins)->inputs();
auto outputs = (*tins)->outputs();
transform_if(
inputs.cbegin(),
inputs.cend(),
std::inserter(params, params.end()),
[&](auto in_param) {
return (params.count(in_param) == 0 and
same_tid_ins_set.count(in_param) == 0);
},
[&](auto in_param) { return in_param; });
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0;
}))
{
return_ins.insert(*tins);
}
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "params ins: \n";
for(auto tmp : iterator_for(params))
{
(*tmp)->debug_print();
}
std::cout << "\n";
std::cout << "return ins: \n";
for(auto tmp : iterator_for(return_ins))
{
(*tmp)->debug_print();
}
std::cout << "\n";
}
auto* tmod = p.create_module("target_mod:" + std::to_string(tid_counter[current_tid]));
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params))
{
auto scalar = get_scalar(*pins);
if(scalar.empty())
{
params_map[*pins] = tmod->add_parameter(
"param:" + std::to_string(param_counter++), (*pins)->get_shape());
rot_ins_params.push_back(*pins);
}
else
{
params_map[*pins] = tmod->add_literal(scalar);
}
}
// TODO: what if params_map is empty ?
for(auto tins : iterator_for(same_tid_ins_vec))
{
auto inputs = (*tins)->inputs();
std::vector<instruction_ref> new_inputs;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(new_inputs),
[&](auto input_ins) { return params_map.at(input_ins); });
// [TODO]: what if it is has module args ?
auto tmod_tins = tmod->add_instruction((*tins)->get_operator(), new_inputs);
// add parameter to params map so that it can be looked up by other insturctions
params_map[*tins] = tmod_tins;
}
std::vector<instruction_ref> rins;
std::unordered_map<instruction_ref, std::size_t> return_ins_idx_map;
for(auto ritr : iterator_for(return_ins))
{
rins.push_back(params_map.at(*ritr));
return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin());
}
tmod->add_return(rins);
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "tmod: \n";
tmod->debug_print();
}
// add run_on_target ins
auto rot_ins =
mm->insert_instruction(ins,
make_op("run_on_target", {{"target_id", current_tid}}),
rot_ins_params,
{tmod});
skip_ins.insert(rot_ins);
// fetch return instructions from tuple
for(auto mm_rins : iterator_for(return_ins))
{
auto tuple_elem_ins = mm->insert_instruction(
ins,
make_op("get_tuple_elem", {{"index", return_ins_idx_map.at(*mm_rins)}}),
rot_ins);
skip_ins.insert(tuple_elem_ins);
// replace returns from tmod in main module
mm->replace_instruction(*mm_rins, tuple_elem_ins);
}
dead_code_elimination{}.apply(*mm);
// update current_tid
if(ins->name() != "@return")
{
current_tid = tass.at(ins);
if(tid_counter.count(current_tid) == 0)
{
tid_counter[current_tid] = 0;
}
tid_counter[current_tid]++;
same_tid_ins_set.clear();
same_tid_ins_vec.clear();
same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins);
}
if(enabled(MIGRAPHX_DEBUG_PARTITIONER{}))
{
std::cout << "program after creation of tmod and rot: \n";
p.debug_print();
}
}
}
dead_code_elimination{}.apply(p);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -41,6 +41,8 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include "migraphx/partitioner.hpp"
#include "migraphx/target_assignments.hpp"
#include "test.hpp"
// check if it is custom_op or run_on_module operator
......@@ -158,6 +160,37 @@ bool check_compiled_program(const migraphx::program& p,
return check_compiled;
}
TEST_CASE(multitarget_partition_compile)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto cpu_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto gpu_ins = mm->add_instruction(migraphx::make_op("add"), cpu_ins, z_param);
mm->add_return({gpu_ins});
p.debug_print();
migraphx::target_assignments tass;
tass.insert(tass.begin(), std::make_pair(cpu_ins, 1));
tass.insert(tass.begin(), std::make_pair(gpu_ins, 0));
migraphx::partition(p, tass);
p.debug_print();
migraphx::compile_options gpu_opts;
gpu_opts.offload_copy = true;
p.compile({migraphx::make_target("gpu"), migraphx::make_target("cpu")}, {gpu_opts});
p.debug_print();
EXPECT(check_compiled_program(p, {migraphx::make_target("gpu"), migraphx::make_target("cpu")}));
migraphx::parameter_map params;
params["x"] = migraphx::fill_argument(s, 1);
params["y"] = migraphx::fill_argument(s, 2);
params["z"] = migraphx::fill_argument(s, 3);
auto result = p.eval(params).back();
auto gold = migraphx::fill_argument(s, 6);
EXPECT(gold == result);
}
TEST_CASE(multitarget_compile_cpu_gpu)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment