Commit 4332ccf6 authored by Paul's avatar Paul
Browse files

Refactor

parent 9b5e0c18
......@@ -31,7 +31,6 @@ int main(int argc, char const* argv[])
std::cout << "Allocating params ... " << std::endl;
auto m = create_param_map(p);
std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m);
}
}
#include <migraphx/dom_info.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// A unified interface to visit programs top-down or bottom-up.
struct program_visitor
{
    program* p_program;
    // When true, the program is walked bottom-up (used for post-dominators).
    bool reversed;
    // NOTE: begin()/end() are INCLUSIVE bounds, not a half-open range: end()
    // returns the last instruction to visit. Loops must compare against end()
    // and break after processing it (see dom_info::compute_dom).
    instruction_ref begin() { return reversed ? std::prev(p_program->end()) : p_program->begin(); }
    instruction_ref end() { return reversed ? p_program->begin() : std::prev(p_program->end()); }
    instruction_ref next(instruction_ref ins) { return reversed ? std::prev(ins) : std::next(ins); }
    // Edges to traverse: producers on a forward walk, consumers when reversed.
    const std::vector<instruction_ref>& get_inputs(instruction_ref ins) const
    {
        return reversed ? ins->outputs() : ins->inputs();
    }
};
// Query whether ins1 strictly post-dominates ins2. ins1 strictly post-dominates
// ins2 if ins1 post-dominates ins2 and ins1 is not ins2.
//
bool dom_info::strictly_post_dominates(const instruction* ins1, const instruction* ins2)
{
    if(ins1 == ins2)
        return false;
    // Walk the immediate-post-dominator chain upward from ins2; ins1 strictly
    // post-dominates ins2 iff it appears somewhere on that chain. A single
    // find() per step replaces the original's three hash lookups per step.
    auto it = instr2_ipdom.find(ins2);
    while(it != instr2_ipdom.end())
    {
        if(it->second == ins1)
            return true;
        it = instr2_ipdom.find(it->second);
    }
    return false;
}
// Among p_ins's dominators, find ones that strictly dominates or post-dominators others.
//
// Among p_ins's dominators (or post-dominators), pick the immediate one: the
// member of the set that does not strictly dominate any other member.
// Writes the result into instr2_dom_tree; `idom` is the chain map used to
// decide strict dominance.
void dom_info::find_dom_tree(
    std::unordered_map<const instruction*, std::set<const instruction*>>& instr2_doms,
    const instruction* p_ins,
    std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree,
    std::unordered_map<const instruction*, const instruction*>& idom)
{
    // Hoist the repeated instr2_doms[p_ins] lookups of the original.
    const std::set<const instruction*>& doms = instr2_doms[p_ins];
    for(const instruction* candidate : doms)
    {
        // True when `candidate` strictly dominates ins2, following the
        // immediate-dominator chain recorded in `idom`.
        auto dom_check = [&dom_tree = idom, ins1 = candidate](const instruction* ins2) {
            if(ins1 == ins2)
                return false;
            auto it = dom_tree.find(ins2);
            while(it != dom_tree.end())
            {
                if(ins1 == it->second)
                    return true;
                it = dom_tree.find(it->second);
            }
            return false;
        };
        // candidate is the immediate dominator when it strictly dominates no
        // other member of the dominator set.
        if(!std::any_of(doms.begin(), doms.end(), dom_check))
        {
            // Exactly one candidate should qualify for a well-formed set.
            assert(instr2_dom_tree.find(p_ins) == instr2_dom_tree.end());
            instr2_dom_tree[p_ins] = candidate;
        }
    }
}
// Compute dominator or post-dominator. Instructions that do not use
// streams are left out.
//
// Compute the dominator (reversed=false) or post-dominator (reversed=true)
// relation over stream-assigned instructions; instructions without a stream
// are left out. Fills instr2_idom or instr2_ipdom respectively.
void dom_info::compute_dom(bool reversed)
{
    std::size_t num_of_instrs = p_program->size();
    if(num_of_instrs == 0)
        return;
    // Full (post-)dominator set for each stream-assigned instruction.
    std::unordered_map<const instruction*, std::set<const instruction*>> instr2_doms;
    // Instruction -> topological point; used only for the debug dump below.
    std::unordered_map<const instruction*, int> instr2_points;
    int cur_points = reversed ? num_of_instrs - 1 : 0;
    bool seen_stream = false;
    // Walks the program bottom-up when computing post-dominators.
    program_visitor vis{p_program, reversed};
    std::unordered_map<const instruction*, const instruction*>& instr2_dom_tree =
        (reversed ? instr2_ipdom : instr2_idom);
    // vis.end() is the LAST instruction to visit (inclusive), hence the
    // explicit `ins == end` breaks instead of a loop condition.
    for(auto ins = vis.begin(), end = vis.end();; ins = vis.next(ins))
    {
        const instruction* p_ins = &(*ins);
        instr2_points[p_ins] = cur_points;
        // Instructions without a stream do not participate in the relation.
        if(ins->get_stream() < 0)
        {
            if(reversed)
                cur_points--;
            else
                cur_points++;
            ;
            if(ins == end)
                break;
            continue;
        }
        seen_stream = true;
        const instruction* p_tmp = nullptr;
        int cnt = 0;
        // Dominator set = intersection of the dominator sets of all
        // stream-assigned inputs (outputs when reversed).
        for(auto&& iter : vis.get_inputs(ins))
        {
            if(iter->get_stream() < 0)
                continue;
            const instruction* p_arg = &(*iter);
            cnt++;
            // Predecessors were visited earlier, so their sets exist.
            assert(instr2_doms.find(p_arg) != instr2_doms.end());
            if(p_tmp == nullptr)
                instr2_doms[p_ins] = instr2_doms[p_arg];
            else
                instr2_doms[p_ins] = set_intersection(instr2_doms[p_ins], instr2_doms[p_arg]);
            p_tmp = p_arg;
        }
        // With a single predecessor the immediate dominator is that
        // predecessor; otherwise derive it from the dominator sets.
        if(cnt == 1)
        {
            instr2_dom_tree[p_ins] = p_tmp;
        }
        else if(cnt > 0)
        {
            std::unordered_map<const instruction*, const instruction*>& idom =
                reversed ? instr2_ipdom : instr2_idom;
            find_dom_tree(instr2_doms, p_ins, instr2_dom_tree, idom);
        }
        // Every instruction dominates itself.
        instr2_doms[p_ins].insert(p_ins);
        if(ins == end)
            break;
        if(reversed)
            cur_points--;
        else
            cur_points++;
    }
    if(seen_stream)
    {
        MIGRAPHX_DEBUG(dump_doms(instr2_points, reversed));
    }
}
// Identify split points. A split point has more than one
// outputs that are executed in different streams.
bool dom_info::is_split_point(instruction_ref ins)
{
    // Track the first stream seen among outputs; any later, different stream
    // makes this instruction a split point.
    int first_seen = -1;
    for(auto&& consumer : ins->outputs())
    {
        int s = consumer->get_stream();
        if(s < 0)
            continue;
        if((first_seen >= 0) && (s != first_seen))
            return true;
        first_seen = s;
    }
    return false;
}
// Identify merge points. A merge point has more than one
// inputs that are executed in different streams.
bool dom_info::is_merge_point(instruction_ref ins)
{
    // Track the first stream seen among inputs; any later, different stream
    // makes this instruction a merge point.
    int first_seen = -1;
    for(auto&& producer : ins->inputs())
    {
        int s = producer->get_stream();
        if(s < 0)
            continue;
        if((first_seen >= 0) && (s != first_seen))
            return true;
        first_seen = s;
    }
    return false;
}
// Propagate split points through the graph and identify concurrent instructions.
// Concurrent instructions have the same split points and different streams.
//
void dom_info::propagate_splits(
    int num_of_streams,
    std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
        concur_instrs,
    std::unordered_map<const instruction*, int>& instr2_points)
{
    std::unordered_map<instruction_ref, bool> is_split;
    std::unordered_map<instruction_ref, bool> is_merge;
    // For each instruction: the set of split points it descends from that
    // have not yet been killed by a post-dominating merge.
    std::unordered_map<instruction_ref, std::set<const instruction*>> split_from;
    int cur_points = 0;
    instr2_points.clear();
    for(auto ins : iterator_for(*p_program))
    {
        const instruction* p_iter = &(*ins);
        instr2_points[p_iter] = cur_points++;
        int stream = ins->get_stream();
        if(stream < 0)
            continue;
        is_split[ins] = is_split_point(ins);
        is_merge[ins] = is_merge_point(ins);
        for(auto&& arg : ins->inputs())
        {
            // Input is a split point.
            // NOTE(review): this tests only MAP MEMBERSHIP (i.e. arg has a
            // stream), not is_split[arg]; presumably the flag was intended —
            // confirm before relying on this.
            if(is_split.find(arg) != is_split.end())
                split_from[ins].insert(&(*arg));
            // Union inputs' split points.
            if((split_from.find(arg) != split_from.end()) && !split_from[arg].empty())
            {
                if(split_from.find(ins) == split_from.end())
                    split_from[ins] = split_from[arg];
                else
                    split_from[ins] = set_union(split_from[ins], split_from[arg]);
            }
        }
        if(is_merge[ins])
        {
            assert(split_from.find(ins) != split_from.end());
            std::set<const instruction*> del_set;
            // post-dominator kills split point.
            for(auto& split : split_from[ins])
            {
                if(strictly_post_dominates(p_iter, split))
                    del_set.insert(split);
            }
            split_from[ins] = set_difference(split_from[ins], del_set);
        }
        if(split_from.find(ins) != split_from.end())
        {
            // Collect concur instructions for each split point, bucketed by
            // the stream they execute on.
            for(auto& split : split_from[ins])
            {
                if(concur_instrs.find(split) == concur_instrs.end())
                {
                    std::vector<std::vector<const instruction*>> instr_stack;
                    instr_stack.resize(num_of_streams);
                    concur_instrs[split] = instr_stack;
                }
                concur_instrs[split][stream].push_back(p_iter);
            }
        }
    }
}
#ifdef MIGRAPHX_DEBUG_OPT
// Debug dump of the computed immediate (post-)dominator for every
// instruction, keyed by its topological point number.
void dom_info::dump_doms(std::unordered_map<const instruction*, int>& instr2_points, bool post_dom)
{
    std::cout << "---dominator tree---" << std::endl;
    for(auto ins : iterator_for(*p_program))
    {
        const instruction* p_ins = &(*ins);
        if(post_dom)
        {
            auto found = instr2_ipdom.find(p_ins);
            if(found != instr2_ipdom.end())
                std::cout << "@" << instr2_points[p_ins] << " imm post domimator: "
                          << "@" << instr2_points[found->second] << std::endl;
        }
        else
        {
            auto found = instr2_idom.find(p_ins);
            if(found != instr2_idom.end())
                std::cout << "@" << instr2_points[p_ins] << " imm dominator: "
                          << "@" << instr2_points[found->second] << std::endl;
        }
    }
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -8,9 +8,10 @@ void memory_coloring::apply(program& p) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&p, allocation_op, verify, num_of_streams, f_concur);
memory_coloring_impl opt(&p, allocation_op, verify);
opt.run();
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -11,8 +11,6 @@ void memory_coloring_impl::run()
if(num_of_lives != 0)
{
MIGRAPHX_DEBUG(dump_intervals());
if(num_of_streams > 0)
add_stream_conflicts();
// Coloring
while(!alloc_queue.empty())
{
......@@ -154,8 +152,7 @@ void memory_coloring_impl::build()
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
instr2_live[&(*arg)] = interval;
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
......@@ -168,7 +165,6 @@ void memory_coloring_impl::build()
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
instr2_live[&(*arg)] = interval;
}
}
if(is_dead)
......@@ -262,57 +258,6 @@ void memory_coloring_impl::verify()
}
}
// Add conflicts of concurrent instructions into conflict table.
//
// Add pairwise conflicts between the live intervals of two groups of
// concurrent instructions, so the allocator never gives them overlapping
// memory. Instructions without a live interval are skipped.
void memory_coloring_impl::add_stream_conflicts(std::vector<const instruction*>& i1,
                                                std::vector<const instruction*>& i2)
{
    for(auto& ins1 : i1)
    {
        // One find() instead of the original find()+operator[] double lookup.
        auto live1 = instr2_live.find(ins1);
        if(live1 == instr2_live.end())
            continue;
        int id1 = live1->second->id;
        for(auto& ins2 : i2)
        {
            auto live2 = instr2_live.find(ins2);
            if(live2 == instr2_live.end())
                continue;
            int id2 = live2->second->id;
            // Conflicts are symmetric.
            conflict_table[id1].insert(id2);
            conflict_table[id2].insert(id1);
#ifdef MIGRAPHX_DEBUG_OPT
            std::cout << "@" << instr2_points[ins1] << " id:" << id1 << " => "
                      << "@" << instr2_points[ins2] << " id:" << id2 << std::endl;
#endif
        }
    }
}
// Identify concurrent instructions in different streams and add conflicts to
// conflict table.
//
// Identify concurrent instructions in different streams and add conflicts to
// the conflict table.
void memory_coloring_impl::add_stream_conflicts()
{
    std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>
        concur_instrs;
    f_concur.get_concur(p_program, num_of_streams, concur_instrs, instr2_points);
    MIGRAPHX_DEBUG(dump_concur_instrs(concur_instrs));
    // For every split point, instructions on any two distinct streams may run
    // concurrently, so each pair of stream groups must conflict.
    for(auto& split_entry : concur_instrs)
    {
        auto& per_stream = split_entry.second;
        for(int first = 0; first < num_of_streams; first++)
        {
            for(int second = first + 1; second < num_of_streams; second++)
            {
                add_stream_conflicts(per_stream[first], per_stream[second]);
            }
        }
    }
}
#ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
......@@ -344,29 +289,6 @@ void memory_coloring_impl::dump_intervals()
}
}
// Debug dump: for each split point, print the concurrent instructions
// grouped by the stream they run on.
void memory_coloring_impl::dump_concur_instrs(
    std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&
        concur_instrs)
{
    for(auto& entry : concur_instrs)
    {
        std::cout << "concurrent instructions for split @" << instr2_points[entry.first]
                  << std::endl;
        for(int stream = 0; stream < num_of_streams; ++stream)
        {
            const std::vector<const instruction*>& instrs = entry.second[stream];
            if(instrs.empty())
                continue;
            std::cout << "stream:" << stream << std::endl;
            for(const instruction* ins : instrs)
            {
                std::cout << " @" << instr2_points[ins];
            }
            std::cout << std::endl;
        }
    }
}
// map liveness tracking point to instruction enum.
static int get_ins_enum(int x)
{
......@@ -409,5 +331,6 @@ void live_interval::dump()
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/common_header.hpp>
#include "common_header.hpp"
#include <migraphx/config.hpp>
#include <migraphx/find_concur.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -53,12 +52,8 @@ using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify, int num, find_concur f)
: p_program(p),
allocation_op(std::move(alloc_op)),
enable_verify(p_verify),
num_of_streams(num),
f_concur(std::move(f))
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{
instr2_live.clear();
live_ranges.clear();
......@@ -79,8 +74,6 @@ struct memory_coloring_impl
conflict_table[val].insert(iter);
}
}
void add_stream_conflicts();
void add_stream_conflicts(std::vector<const instruction*>&, std::vector<const instruction*>&);
void build();
void run();
void rewrite();
......@@ -112,8 +105,6 @@ struct memory_coloring_impl
void dump(const std::string&);
void dump_program();
void dump_intervals();
void dump_concur_instrs(
std::unordered_map<const instruction*, std::vector<std::vector<const instruction*>>>&);
#endif
struct ordering
{
......@@ -139,7 +130,6 @@ struct memory_coloring_impl
return (i1->offset > i2->offset);
}
};
program* p_program;
std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals.
......@@ -150,7 +140,7 @@ struct memory_coloring_impl
std::unordered_map<int, std::set<int>> conflict_table;
// Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue;
std::unordered_map<const instruction*, int> instr2_points;
int num_of_lives;
int max_value_number;
long long required_bytes;
......@@ -162,9 +152,8 @@ struct memory_coloring_impl
bool unify_literals;
std::string allocation_op{};
bool enable_verify;
int num_of_streams;
find_concur f_concur;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/pre_scheduling.hpp>
#include "pre_scheduling_impl.hpp"
namespace migraphx {
// Entry point of the pre-scheduling pass: builds the weighted DAG, assigns
// streams and reorders p's instructions, unless disabled via the
// MIGRAPHX_DISABLE_PRE_SCHEDULING environment variable.
void pre_scheduling::apply(program& p) const
{
    if(!enabled(MIGRAPHX_DISABLE_PRE_SCHEDULING{}))
    {
        pre_scheduling_impl opt(&p, weight_func, num_of_streams, insert_instr, verify);
        opt.run();
    }
}
} // namespace migraphx
#include "pre_scheduling_impl.hpp"
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <stack>
namespace migraphx {
// Compute accumulated weights for each node in the DAG. Collect exit nodes
// and sort them according to accumulated weights.
//
// Populate one dag_node per instruction: its own weight (from weight_func),
// the accumulated weight of it plus all transitive producers, and the
// instruction index. Collects exit nodes (no outputs) and sorts them by
// descending accumulated weight.
void pre_scheduling_impl::compute_weights()
{
    int ndx = 0;
    // Producers already counted for the current instruction, so a value used
    // twice by the same instruction is only accumulated once.
    std::unordered_map<dag_node*, bool> visited;
    for(auto ins : iterator_for(*p_program))
    {
        dag_node& node = nodes[ndx];
        // weight_func returns {cost weight, run-on-cpu flag} for the op.
        std::pair<int, int> weight = weight_func(ins->get_operator());
        node.weight     = weight.first;
        node.run_on_cpu = weight.second;
        node.weight_sum += node.weight;
        visited.clear();
        for(auto&& arg : ins->inputs())
        {
            // Inputs precede their users in program order, so their nodes
            // are already populated.
            assert(instr2_node.find(arg) != instr2_node.end());
            dag_node* def_node = instr2_node[arg];
            if(visited.find(def_node) == visited.end())
            {
                node.weight_sum += def_node->weight_sum;
                visited[def_node] = true;
            }
        }
        if(ins->outputs().empty())
        {
            exit_nodes.push_back(&node);
        }
        node.ins         = ins;
        node.ins_ndx     = ndx++;
        instr2_node[ins] = &node;
    }
    // Heaviest exit nodes first. (Compares size() directly, avoiding the
    // signed/unsigned narrowing of the old `int size = exit_nodes.size()`.)
    if(exit_nodes.size() > 1)
    {
        std::sort(exit_nodes.begin(), exit_nodes.end(), compare_exit_nodes);
    }
}
// Do topology sort according to accumulated weight. Identify critial paths.
// Schedule nodes into streams. Reorder instructions according to topological
// order and annoate streams and events in the instructions.
//
// Topologically sort the DAG by accumulated weight (iterative DFS from each
// exit node), assigning partitions along the way: the heaviest child stays
// on its parent's partition (critical path), other sufficiently heavy
// children open new partitions. Then schedule streams, reorder the program
// and annotate stream/event instructions.
void pre_scheduling_impl::reorder()
{
    std::list<dag_node*> sorted_nodes;
    std::stack<dag_node*> stack;
    // Orders ready children so the heaviest is popped from the DFS stack last
    // (i.e. visited deepest) — see "critical path" note below.
    std::priority_queue<dag_node*, std::vector<dag_node*>, weighted_topology_ordering> child_queue;
    std::unordered_map<dag_node*, bool> visited;
    std::unordered_map<dag_node*, bool> dequeued;
    for(auto&& node : exit_nodes)
    {
        stack.push(node);
        node->partition = partition_info.create_partition();
        partition_info.add_weight(node);
        while(!stack.empty())
        {
            auto cur = stack.top();
            // Already emitted on a previous exit-node walk.
            if(dequeued.find(cur) != dequeued.end())
            {
                stack.pop();
                continue;
            }
            // All children processed (or none): emit in topological order.
            else if((visited.find(cur) != visited.end()) || cur->ins->inputs().empty())
            {
                stack.pop();
                sorted_nodes.push_back(cur);
                dequeued[cur] = true;
                continue;
            }
            // sort child nodes.
            for(auto&& arg : cur->ins->inputs())
            {
                dag_node* child_node = instr2_node[arg];
                if(dequeued.find(child_node) == dequeued.end())
                {
                    child_queue.push(child_node);
                }
            }
            // Last item in queue is on critical path.
            while(!child_queue.empty())
            {
                dag_node* child = child_queue.top();
                stack.push(child);
                child_queue.pop();
                // Light children ride along on the parent's partition; heavy
                // non-critical children get their own partition.
                if(child->weight_sum < min_partition_threshold)
                    child->partition = cur->partition;
                else if(!child_queue.empty())
                    child->partition = partition_info.create_partition();
                else
                {
                    cur->first_child = child;
                    child->partition = cur->partition;
                }
                partition_info.add_weight(child);
            }
            visited[cur] = true;
        }
    }
#ifdef MIGRAPHX_DEBUG_OPT
    MIGRAPHX_DEBUG(dump("---After weighted topology sort---"));
    MIGRAPHX_DEBUG(dump(sorted_nodes));
#endif
    schedule(sorted_nodes);
    splice(sorted_nodes);
    annotate(sorted_nodes);
    if(enable_verify)
        verify();
}
// Assign stream to nodes according to load balance.
//
// Pick a stream for the first node of a partition, balancing load: prefer a
// stream with enough slack before the current max cycle to absorb the whole
// partition; otherwise fall back to the least-busy stream.
int pre_scheduling_impl::get_stream(stream_info& info, dag_node* node) const
{
    int max_cycle = info.max_cycle;
    // Nothing scheduled yet: stream 0.
    if(max_cycle == 0)
        return 0;
    int partition_load = partition_info.weight_sum[node->partition];
    int earliest_cycle = node->earliest_cycle;
    int min_cycle = -1;
    int min_cycle_stream = -1;
    for(auto stream = 0; stream < num_of_streams; ++stream)
    {
        // Cycle at which this node could start on `stream`.
        int cycle = std::max(info.next_cycles[stream], earliest_cycle);
        // Enough slack before max_cycle to fit the whole partition: take it.
        if((cycle < max_cycle) && ((max_cycle - cycle) > partition_load))
            return stream;
        if((min_cycle_stream == -1) || (cycle < min_cycle))
        {
            min_cycle = cycle;
            min_cycle_stream = stream;
        }
    }
    return min_cycle_stream;
}
// Record the stream-assignment.
//
// Record the stream assignment of `node`: advance its stream's next free
// cycle, update the global max cycle, and push the finish time to all
// consumers as their new earliest start cycle.
void pre_scheduling_impl::record(stream_info& info, dag_node* node)
{
    int stream = node->stream;
    int next_cycle = info.next_cycles[stream];
    // Start no earlier than both the stream's availability and the node's
    // data dependencies allow.
    node->sched_cycle = std::max(node->earliest_cycle, next_cycle);
    next_cycle = node->sched_cycle + node->weight;
    info.next_cycles[stream] = next_cycle;
    info.max_cycle = std::max(info.max_cycle, next_cycle);
    for(auto&& arg : node->ins->outputs())
    {
        assert(instr2_node.find(arg) != instr2_node.end());
        dag_node* use_node = instr2_node[arg];
        use_node->earliest_cycle = std::max(use_node->earliest_cycle, next_cycle);
    }
    // Only stream-capable (non-CPU) instructions get a stream annotation.
    if(node->can_use_stream())
        instr2_stream[node->ins] = stream;
}
// Assign nodes to streams.
//
// Assign each topologically sorted node to a stream. Nodes in the same
// partition share a stream; the first node of a partition chooses one via
// the load-balance heuristic in get_stream().
void pre_scheduling_impl::schedule(std::list<dag_node*>& sorted_nodes)
{
    if(num_of_streams == 0)
        return;
    stream_info info(num_of_streams);
    // partition id -> stream chosen for that partition. (Dropped the
    // redundant clear() immediately after construction.)
    std::unordered_map<int, int> partition2_stream;
    for(auto&& node : sorted_nodes)
    {
        int cur_partition = node->partition;
        assert(cur_partition >= 0);
        // Single lookup instead of find()+operator[].
        auto found = partition2_stream.find(cur_partition);
        if(found != partition2_stream.end())
            node->stream = found->second;
        else
            node->stream = get_stream(info, node);
        assert(node->stream >= 0);
        record(info, node);
        partition2_stream[cur_partition] = node->stream;
    }
#ifdef MIGRAPHX_DEBUG_OPT
    MIGRAPHX_DEBUG(dump("---After assigning stream---"));
    MIGRAPHX_DEBUG(dump(sorted_nodes));
#endif
}
// Reorder the instructions ino topological order.
//
// Reorder the program's instructions into the topological order of
// sorted_nodes by walking the list backwards and moving each instruction
// immediately before the previously placed one.
void pre_scheduling_impl::splice(std::list<dag_node*>& sorted_nodes)
{
    if(sorted_nodes.size() <= 1)
        return;
    auto begin = sorted_nodes.begin();
    auto iter = sorted_nodes.end();
    // The last sorted node stays put and anchors the insertions.
    instruction_ref insert_before = (*(--iter))->ins;
    do
    {
        iter--;
        insert_before = p_program->move_instruction((*iter)->ins, insert_before);
    } while(iter != begin);
#ifdef MIGRAPHX_DEBUG_OPT
    MIGRAPHX_DEBUG(dump("---After splice in pre-scheduling---"));
    MIGRAPHX_DEBUG(dump_program());
#endif
}
// Annotate streams and events in the instruction. Insert set_stream
// instructions.
//
// Annotate streams and events on the (already reordered) instructions:
// set each instruction's stream, insert a set_stream instruction whenever
// the active stream changes, and insert record/wait event pairs for every
// cross-stream dependency. Each producer records its event at most once.
void pre_scheduling_impl::annotate(std::list<dag_node*>& sorted_nodes)
{
    int event = 0;
    int last_stream = -1;
    for(auto&& node : sorted_nodes)
    {
        instruction_ref ins = node->ins;
        // Skip instructions that were not assigned a stream.
        if(instr2_stream.find(ins) == instr2_stream.end())
            continue;
        int stream = instr2_stream[ins];
        ins->set_stream(stream);
        // Emit a set_stream only on stream changes.
        if(last_stream != stream)
        {
            insert_instr.insert_stream(p_program, ins, stream);
            last_stream = stream;
        }
        std::vector<int> events;
        for(auto&& arg : ins->inputs())
        {
            if(instr2_stream.find(arg) == instr2_stream.end())
                continue;
            int arg_s = instr2_stream[arg];
            // Same-stream dependencies need no synchronization.
            if(arg_s == stream)
                continue;
            // Record the producer's event once, right after the producer.
            if(!has_mask(arg, record_event))
            {
                events.push_back(event);
                insert_instr.insert_record_event(p_program, std::next(arg), event);
                event++;
            }
            add_mask(arg, record_event);
            add_mask(ins, wait_event);
        }
        // Wait on all newly recorded events before this instruction.
        for(auto&& i : events)
            insert_instr.insert_wait_event(p_program, ins, i);
    }
}
// Pass driver: size the node table, compute DAG weights, then reorder,
// schedule and annotate the program. No-op on an empty program.
void pre_scheduling_impl::run()
{
    const std::size_t ins_count = p_program->size();
    if(ins_count == 0)
        return;
    MIGRAPHX_DEBUG(dump("---Before pre-scheduling---"));
    MIGRAPHX_DEBUG(dump_program());
    nodes.resize(ins_count);
    compute_weights();
    reorder();
}
void pre_scheduling_impl::verify()
{
std::unordered_map<instruction_ref, bool> visited;
for(auto ins : iterator_for(*p_program))
{
for(auto&& arg : ins->inputs())
{
if(visited.find(arg) == visited.end())
MIGRAPHX_THROW("Input not visited");
}
visited[ins] = true;
}
}
#ifdef MIGRAPHX_DEBUG_OPT
// Debug helper: print an arbitrary message line.
void pre_scheduling_impl::dump(const std::string& str) { std::cout << str << std::endl; }
// Debug helper: print the whole program.
void pre_scheduling_impl::dump_program() { std::cout << *p_program << std::endl; }
// Debug helper: print every sorted node followed by the indices of the
// nodes defining its inputs.
void pre_scheduling_impl::dump(std::list<dag_node*>& sorted_nodes)
{
    for(auto&& entry : sorted_nodes)
    {
        entry->dump();
        auto&& inputs = entry->ins->inputs();
        if(inputs.empty())
            continue;
        std::cout << " inputs: ";
        for(auto&& arg : inputs)
        {
            std::cout << " @" << instr2_node[arg]->ins_ndx;
        }
        std::cout << std::endl;
    }
}
// Debug helper: print this node's index, op name, weights, stream (when the
// node can use one), partition and scheduled cycle on a single line.
void dag_node::dump()
{
    std::cout << " @" << ins_ndx;
    std::cout << " name: " << ins->name();
    std::cout << " weight: " << weight;
    std::cout << " weight_sum: " << weight_sum;
    if(can_use_stream())
        std::cout << " stream: " << stream;
    std::cout << " partition: " << partition;
    std::cout << " sched_cycle: " << sched_cycle;
    std::cout << std::endl;
}
#endif
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRE_SCHEDULING_IMPL_HPP
#include <migraphx/common_header.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/insert_instruction.hpp>
namespace migraphx {
// A node in the scheduling DAG, one per instruction. All members carry
// in-class initializers, so the default constructor needs no body.
// (Previously earliest_cycle was initialized twice — in-class AND in the
// constructor — while every other member was initialized only in the
// constructor; the values are unchanged.)
struct dag_node
{
    dag_node() = default;
    // Cost of this instruction as reported by the weight function.
    int weight = 0;
    // Nonzero when the op must run on the CPU and cannot use a stream.
    int run_on_cpu = 0;
    // Accumulated weight: this node plus all transitive producers.
    int weight_sum = 0;
    // Index of the instruction in program order; -1 until assigned.
    int ins_ndx = -1;
    // Child on the critical path, if any.
    dag_node* first_child = nullptr;
    // Assigned stream; -1 when unassigned.
    int stream = -1;
    // Partition id; -1 when unassigned.
    int partition = -1;
    // Cycle at which the node was scheduled; -1 when unscheduled.
    int sched_cycle = -1;
    // Earliest cycle this node may start, pushed from producer finish times.
    int earliest_cycle = -1;
    instruction_ref ins;
    bool is_literal() const { return (ins->name() == "@literal"); }
    bool can_use_stream() const { return (run_on_cpu == 0); }
#ifdef MIGRAPHX_DEBUG_OPT
    void dump();
#endif
};
// Tracks the partitions of the DAG and the total weight assigned to each.
struct dag_partition
{
    // weight_sum is empty by construction, so the old explicit clear() in
    // the constructor body was redundant and has been removed.
    dag_partition() : num_of_partition(0) {}
    // Open a new partition and return its id.
    int create_partition()
    {
        weight_sum.push_back(0);
        return num_of_partition++;
    }
    // Accumulate a node's weight into its partition; nodes without a
    // partition (-1) are ignored.
    void add_weight(dag_node* node)
    {
        if(node->partition >= 0)
        {
            assert(node->partition < num_of_partition);
            weight_sum[node->partition] += node->weight;
        }
    }
    // Number of partitions created so far.
    int num_of_partition;
    // Per-partition accumulated weight, indexed by partition id.
    std::vector<int> weight_sum;
};
// Per-stream scheduling state: the next free cycle of every stream plus the
// maximum cycle in use overall.
struct stream_info
{
    // n streams, all starting at cycle 0. Uses the member-init list instead
    // of clear()+push_back in the body; a non-positive n yields no streams,
    // matching the original loop's behavior.
    stream_info(int n)
        : next_cycles(static_cast<std::size_t>(n > 0 ? n : 0), 0), num_of_streams(n), max_cycle(0)
    {
    }
    // next_cycles[s] is the first cycle at which stream s is free.
    std::vector<int> next_cycles;
    int num_of_streams;
    int max_cycle;
};
// Bit positions stored in instr2_mask to track per-instruction event state
// during annotate().
enum instruction_mask : unsigned int
{
    record_event = 0, // a record_event was already inserted after this instruction
    wait_event   = 1  // this instruction waits on at least one event
};
// Implementation of the pre-scheduling pass: builds a weighted DAG over the
// program, partitions it, assigns partitions to streams, reorders the
// instructions topologically and inserts stream/event annotations.
struct pre_scheduling_impl
{
    // p: program to schedule; w: maps an operation to {weight, run-on-cpu};
    // n: number of streams; ins: helper that inserts stream/event
    // instructions; v: run the topological-order verifier afterwards.
    pre_scheduling_impl(program* p,
                        std::function<std::pair<int, int>(const operation&)> w,
                        int n,
                        insert_instruction ins,
                        bool v)
        : p_program(p),
          weight_func(std::move(w)),
          num_of_streams(n),
          insert_instr(std::move(ins)),
          enable_verify(v)
    {
        instr2_node.clear();
        instr2_mask.clear();
        instr2_stream.clear();
    }
    void schedule(std::list<dag_node*>&);
    void compute_weights();
    int get_stream(stream_info&, dag_node*) const;
    void record(stream_info&, dag_node*);
    void reorder();
    void run();
    void splice(std::list<dag_node*>&);
    void annotate(std::list<dag_node*>&);
    // Exit nodes are ordered by descending accumulated weight.
    static bool compare_exit_nodes(dag_node* d1, dag_node* d2)
    {
        return (d1->weight_sum > d2->weight_sum);
    }
    // Priority-queue ordering used during the weighted topological sort.
    struct weighted_topology_ordering
    {
        bool operator()(const dag_node* d1, const dag_node* d2) const
        {
            if(d1->weight_sum < d2->weight_sum)
            {
                // smaller weight_sum is placed on top of the queue.
                return false;
            }
            else if(d1->weight_sum > d2->weight_sum)
            {
                return true;
            }
            else
            {
                // smaller instruction index is placed on top of the queue.
                return d1->ins_ndx > d2->ins_ndx;
            }
        }
    };
    // Ordering by scheduled cycle, then stream, then instruction index.
    struct post_schedule_ordering
    {
        bool operator()(const dag_node* d1, const dag_node* d2) const
        {
            if(d1->sched_cycle == d2->sched_cycle)
            {
                if(d1->stream == d2->stream)
                {
                    // smaller instruction index on top of queue.
                    return d1->ins_ndx > d2->ins_ndx;
                }
                else
                {
                    // smaller stream on top of queue.
                    return (d1->stream > d2->stream);
                }
            }
            else
            {
                // smaller sched_cycle on top of queue.
                return (d1->sched_cycle > d2->sched_cycle);
            }
        }
    };
    // True when bit m is set in ins's event mask.
    bool has_mask(instruction_ref ins, unsigned int m)
    {
        if(instr2_mask.find(ins) != instr2_mask.end())
        {
            unsigned int mask = instr2_mask[ins];
            return ((mask & (1u << m)) != 0);
        }
        return false;
    }
    // Set bit m in ins's event mask (no-op when already set).
    void add_mask(instruction_ref ins, unsigned int m)
    {
        unsigned int mask = (instr2_mask.find(ins) != instr2_mask.end()) ? instr2_mask[ins] : 0;
        if((mask & (1u << m)) == 0)
            instr2_mask[ins] = (mask + (1u << m));
    }
    void verify();
#ifdef MIGRAPHX_DEBUG_OPT
    void dump(const std::string&);
    void dump_program();
    void dump(std::list<dag_node*>&);
#endif
    // Nodes lighter than this ride along on their parent's partition.
    static const int min_partition_threshold = 2;

    private:
    program* p_program;
    std::function<std::pair<int, int>(const operation&)> weight_func;
    int num_of_streams;
    insert_instruction insert_instr;
    // One dag_node per instruction, indexed in program order.
    std::vector<dag_node> nodes;
    // Nodes with no outputs; DFS roots of the topological sort.
    std::vector<dag_node*> exit_nodes;
    std::unordered_map<instruction_ref, dag_node*> instr2_node;
    // Stream assigned to each stream-capable instruction.
    std::unordered_map<instruction_ref, int> instr2_stream;
    // Per-instruction event bit mask (see instruction_mask).
    std::unordered_map<instruction_ref, unsigned int> instr2_mask;
    dag_partition partition_info;
    bool enable_verify;
};
} // namespace migraphx
#endif
......@@ -7,7 +7,6 @@
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
......@@ -16,9 +15,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl
{
// A list is used to keep references to an instruction stable
......@@ -54,8 +50,7 @@ static void print_instruction(std::ostream& os,
}
os << ")";
}
if(ins->get_stream() >= 0)
os << "(stream=" << ins->get_stream() << ")";
os << " -> " << ins->get_shape();
}
......@@ -109,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
// TODO: Use move
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
// assert(result->inputs() == args);
assert(result->valid(begin()));
return result;
}
......@@ -332,8 +325,6 @@ void program::finalize()
}
}
void program::finish() { this->impl->ctx.finish(); }
template <class F>
argument generic_eval(const program& p,
context& ctx,
......@@ -345,7 +336,6 @@ argument generic_eval(const program& p,
results.reserve(p.size() * 2);
std::vector<argument> values;
values.reserve(16);
for(auto ins : iterator_for(p))
{
if(ins->name() == "@literal")
......@@ -378,7 +368,6 @@ argument generic_eval(const program& p,
assert(results.find(i) != results.end());
return results[i];
});
results.emplace(ins, trace(ins, [&] {
return ins->get_operator().compute(ctx, ins->get_shape(), values);
}));
......@@ -516,6 +505,16 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
if(ins == this->end())
{
std::cout << "End instruction" << std::endl;
return;
}
if(not has_instruction(ins))
{
std::cout << "Instruction not part of program" << std::endl;
return;
}
std::stringstream ss;
print_program(ss, *this, [&](auto x, auto&& names) {
if(x == ins)
......@@ -538,6 +537,11 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
}
// Print the program to os, invoking `a` for each instruction as it is
// printed so callers can attach per-instruction annotations to the stream.
void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
    print_program(os, *this, [&](auto ins, auto&&) { a(ins); });
}
// Programs compare equal when their printed representations match.
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
......@@ -545,5 +549,6 @@ std::ostream& operator<<(std::ostream& os, const program& p)
print_program(os, p, [](auto&&...) {});
return os;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
#include <unordered_set>
#include <set>
#include <deque>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Selector returning an instruction's producers (for backward traversal).
auto get_inputs()
{
    return [](auto i) { return i->inputs(); };
}
// Selector returning an instruction's consumers (for forward traversal).
auto get_outputs()
{
    return [](auto i) { return i->outputs(); };
}
struct stream_info
{
std::unordered_map<instruction_ref, std::size_t> ins2stream;
std::unordered_map<instruction_ref, std::size_t> weights;
std::unordered_map<instruction_ref, std::size_t> iweights;
// Compute, for every instruction reachable from `last`, its own weight
// (iweights) and the accumulated weight of itself plus all transitive
// inputs (weights). Memoized via the `contains(weights, ins)` check, so
// shared subgraphs are visited once.
void accumulate_weights(instruction_ref last, const schedule_model& model)
{
    // fix() supplies a self-recursive lambda (`self`).
    fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
        if(not contains(weights, ins))
        {
            std::size_t weight = 0;
            auto&& op = ins->get_operator();
            // Context-free ops and internal (@-prefixed) ops cost nothing.
            if(not is_context_free(op) and op.name()[0] != '@')
                weight = model.weight(op);
            iweights[ins] = weight;
            weights[ins] =
                std::accumulate(ins->inputs().begin(),
                                ins->inputs().end(),
                                weight,
                                [&](std::size_t w, instruction_ref i) { return w + self(i); });
        }
        return weights[ins];
    })(last);
}
// Sort args by descending accumulated weight (ties broken by input count)
// and return an iterator to the first "light" argument (weight at or below
// the threshold). Arguments before the returned iterator — excluding the
// heaviest, which always stays with the caller's partition — are candidates
// for their own partitions.
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{
    if(args.size() < 2)
    {
        return args.end();
    }
    const std::size_t min_partition_threshold = 2;
    auto compare = by(std::greater<>{}, [&](auto x) {
        return std::make_tuple(this->weights[x], x->inputs().size());
    });
    std::sort(args.begin(), args.end(), compare);
    // lower_bound over the descending-sorted tail finds where the weights
    // drop to/below the threshold.
    auto it = std::lower_bound(std::next(args.begin()),
                               args.end(),
                               min_partition_threshold,
                               [&](auto i, std::size_t w) { return this->weights[i] > w; });
    assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
    assert(it == args.end() or std::prev(it) == args.begin() or
           this->weights[*std::prev(it)] > min_partition_threshold);
    return it;
}
// A group of instructions intended for one stream, with its total weight.
struct partition
{
    std::size_t weight = 0;
    std::vector<instruction_ref> instructions{};
    // Append ins (whose weight is w) to this partition.
    void add(instruction_ref ins, std::size_t w)
    {
        weight += w;
        instructions.push_back(ins);
    }
};
// Partition the program by walking backwards from the last instruction:
// the heaviest input chain stays on the critical partition (stream 0),
// heavy side-chains become new partitions, and each side partition is then
// greedily assigned to the least-loaded of the remaining n-1 streams.
// Also moves instructions into DFS-post order as a side effect.
// NOTE(review): assumes n >= 2 — `streams(n - 1)` underflows for n == 0 and
// is empty for n == 1, which would make min_element below return end();
// confirm callers guarantee this.
void assign_streams(program& p, std::size_t n)
{
    partition critical;
    std::unordered_map<instruction_ref, std::deque<partition>> partitions;
    partitions.reserve(weights.size());
    fix([&](auto self, auto ins, auto& part) {
        assert(ins != p.end());
        if(contains(partitions, ins))
            return;
        assert(p.has_instruction(ins));
        // Add an entry so we know the instruction was visited
        partitions[ins];
        part.add(ins, this->iweights[ins]);
        auto args = ins->inputs();
        auto threshold_it = this->sort_args(args);
        if(not args.empty())
        {
            assert(threshold_it != args.begin());
            // Heaviest input continues the current partition.
            self(args.front(), part);
            // Heavy side inputs open their own partitions.
            for(auto i : range(std::next(args.begin()), threshold_it))
            {
                partitions[ins].emplace_back();
                self(i, partitions[ins].back());
            }
            // Light inputs ride along with the current partition.
            for(auto i : range(threshold_it, args.end()))
            {
                self(i, part);
            }
        }
        // Sort instructions
        p.move_instruction(ins, p.end());
    })(std::prev(p.end()), critical);
    // Set the critical partition to stream 0
    set_stream(critical, 0);
    std::vector<std::size_t> streams(n - 1);
    // Assign streams for the other partitions
    for(auto&& ins_part : partitions)
    {
        // Heaviest partitions first so the greedy balance works well.
        std::sort(
            ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) {
                return std::make_tuple(x.weight, x.instructions.size());
            }));
        for(auto&& part : ins_part.second)
        {
            // Pick the currently least-loaded non-critical stream.
            auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
            set_stream(part, stream + 1);
            streams[stream] += part.weight;
        }
    }
}
// Assign stream n to every weighted instruction of the partition.
// Zero-weight instructions never carry a stream assignment.
void set_stream(const partition& p, std::size_t n)
{
    for(auto ins : p.instructions)
    {
        if(iweights[ins] == 0)
            continue;
        set_stream(ins, n);
    }
}
// Record that ins runs on stream n; only weighted instructions may be
// assigned a stream.
void set_stream(instruction_ref ins, std::size_t n)
{
    assert(iweights[ins] > 0);
    ins2stream[ins] = n;
}
// Stream assigned to ins; ins must have one (check has_stream first)
std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
// True if ins has been assigned a stream
bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); }
// True if the stream enumerator f yields any stream different from the
// given baseline. The callback returns false to stop enumeration early
// once a difference is found.
template <class F>
bool different(F f, std::size_t stream) const
{
    bool result = false;
    f([&](auto s) {
        if(s != stream)
        {
            result = true;
            // Difference found; stop enumerating
            return false;
        }
        // No-op when s == stream; kept to document intent
        // cppcheck-suppress uselessAssignmentArg
        stream = s;
        return true;
    });
    return result;
}
// True if the stream enumerator f yields at least two distinct streams.
// The first stream seen becomes the baseline, then f is re-invoked by
// the two-argument overload to compare the rest against it.
template <class F>
bool different(F f) const
{
    bool result = false;
    f([&](auto s) {
        result = this->different(f, s);
        // Only the first stream is needed as the baseline
        return false;
    });
    return result;
}
// Returns a callable enumerating the streams reachable from start via
// the neighbors chosen by select (inputs or outputs). Zero-weight
// instructions are transparent: they are traversed through instead of
// reported. Enumeration stops early when the callback f returns false,
// and the callable itself then returns false.
template <class Selector>
auto get_streams_from(instruction_ref start, Selector select) const
{
    return [=](auto f) {
        return fix<bool>([&](auto self, auto ins) {
            for(auto i : select(ins))
            {
                if(iweights.at(i) == 0)
                {
                    // Transparent instruction: recurse through it
                    if(not self(i))
                        return false;
                }
                else
                {
                    // Weighted instruction: report its stream
                    if(not f(this->get_stream(i)))
                        return false;
                }
            }
            return true;
        })(start);
    };
}
// The set of streams associated with ins: its own stream when it has
// one, otherwise the streams of the weighted instructions feeding it.
std::unordered_set<std::size_t> get_streams(instruction_ref ins) const
{
    if(has_stream(ins))
        return {get_stream(ins)};
    std::unordered_set<std::size_t> result;
    auto collect = [&](auto s) {
        result.insert(s);
        return true; // never stop early; gather every input stream
    };
    get_streams_from(ins, get_inputs())(collect);
    return result;
}
// True when ins's weighted inputs span more than one stream (or, when a
// stream is passed explicitly, come from a stream other than it).
template <class... Ts>
bool is_merge_point(instruction_ref ins, Ts... xs) const
{
    return different(get_streams_from(ins, get_inputs()), xs...);
}
// True when ins's weighted outputs span more than one stream (or, when a
// stream is passed explicitly, go to a stream other than it).
template <class... Ts>
bool is_split_point(instruction_ref ins, Ts... xs) const
{
    return different(get_streams_from(ins, get_outputs()), xs...);
}
// For each stream feeding start (directly or through zero-weight
// instructions), pick the input instruction on that stream occurring
// latest in the program. These are the instructions whose completion
// must be recorded so start can wait on them.
std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
{
    std::vector<instruction_ref> result;
    // stream id -> latest-occurring input instruction on that stream
    std::unordered_map<std::size_t, instruction_ref> m;
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
        {
            if(iweights.at(i) == 0)
            {
                // Zero-weight instructions carry no stream; look through
                self(i);
                continue;
            }
            auto stream = this->get_stream(i);
            if(not contains(m, stream))
                m[stream] = i;
            else
                // Smaller distance to start means the instruction occurs
                // later in the program, so keep the minimum
                m[stream] = std::min(m[stream], i, by(std::less<>{}, [&](auto x) {
                    return std::distance(x, start);
                }));
        }
    })(start);
    std::transform(
        m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.second; });
    return result;
}
// For every merge point, collect per-stream lists of the instructions
// that may execute concurrently before it. The program is walked
// bottom-up, propagating the set of downstream merge points to each
// instruction's producers.
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(program& p)
{
    // merge point -> one instruction list per stream index
    std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
    // instruction -> the merge points downstream of it
    std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
    result.reserve(p.size());
    merge_from.reserve(p.size());
    for(auto ins : reverse_iterator_for(p))
    {
        for(auto&& arg : ins->outputs())
        {
            // A merging consumer is itself a merge point for ins
            if(is_merge_point(arg))
                merge_from[ins].insert(arg);
            // Inherit the merge points already found downstream of arg
            merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
        }
        auto streams = this->get_streams(ins);
        // Collect concur instructions for each merge point.
        for(auto& merge : merge_from[ins])
        {
            for(auto stream : streams)
            {
                // Grow the per-stream list vector on demand
                if(result[merge].size() <= stream)
                    result[merge].resize(stream + 1);
                auto&& r = result[merge][stream];
                r.push_back(ins);
                // Copy inputs if they dont have a stream(and are not a builtin and context
                // free). Inputs without a stream can have a implicit dependency
                std::copy_if(ins->inputs().begin(),
                             ins->inputs().end(),
                             std::back_inserter(r),
                             [&](auto x) {
                                 return not this->has_stream(x) and
                                        not is_context_free(x->get_operator()) and
                                        x->name().front() != '@';
                             });
            }
        }
    }
    return result;
}
// Build a conflict table: for each instruction, the set of instructions
// that may execute concurrently with it. Instructions reaching the same
// merge point on different streams conflict with each other.
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p)
{
    std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table;
    auto concur_ins = this->find_concurrent_instructions(p);
    for(auto&& merge : concur_ins)
    {
        // Pair instructions from every pair of distinct streams
        dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
            if(i == j)
                return;
            for(auto ins1 : merge.second[i])
            {
                auto p1 = std::distance(ins1, merge.first);
                for(auto ins2 : merge.second[j])
                {
                    if(ins1 == ins2)
                        continue;
                    auto p2 = std::distance(ins2, merge.first);
                    // The smaller distance means the instruction occurs later;
                    // key the conflict on the later-occurring instruction
                    if(p1 > p2)
                        conflict_table[ins2].insert(ins1);
                    else
                        conflict_table[ins1].insert(ins2);
                }
            }
        });
    }
    // Remove duplicates: keep each conflict pair in one direction only.
    // Use find() rather than operator[] here: operator[] could insert a
    // new key while the table is being range-iterated, which may rehash
    // the unordered_map and invalidate the loop iterators. (The phantom
    // empty sets operator[] created were skipped by callers anyway.)
    for(auto&& ip : conflict_table)
    {
        auto ins1 = ip.first;
        for(auto ins2 : ip.second)
        {
            auto it = conflict_table.find(ins2);
            if(it != conflict_table.end())
                it->second.erase(ins1);
        }
    }
    return conflict_table;
}
};
// Apply the scheduling pass: compute weights, partition the program into
// streams, insert the model's stream/record/wait instructions, and make
// memory conflicts between concurrent instructions explicit.
void schedule::apply(program& p) const
{
    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());
    // Optional trace of the computed weights and stream assignments
    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
    {
        p.annotate(std::cout, [&](auto ins) {
            std::cout << ":";
            std::cout << " weight=" << si.weights.at(ins);
            std::cout << " input={";
            si.get_streams_from(ins, get_inputs())([&](auto s) {
                std::cout << s << ",";
                return true;
            });
            std::cout << "}";
            if(si.has_stream(ins))
                std::cout << " stream=" << si.get_stream(ins);
        });
        std::cout << std::endl;
    }
    // Schedule instructions
    std::size_t wait_id = 0;
    // Event id recorded after each instruction
    std::unordered_map<instruction_ref, std::size_t> ins2wait;
    // Events each stream has already waited on
    std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
    // Events that had been waited on before each split point
    std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
    ins2wait.reserve(p.size());
    ins2waited.reserve(p.size());
    for(auto ins : iterator_for(p))
    {
        // Only schedule instructions that have a stream
        if(not si.has_stream(ins))
            continue;
        assert(si.weights[ins] > 0);
        // Schedule instruction on the stream
        auto stream = si.get_stream(ins);
        assert(stream < model.concurrency());
        model.sched(p, ins, stream);
        // Insert wait instructions
        if(si.is_merge_point(ins, stream))
        {
            for(auto i : si.get_recorded_instructions(ins))
            {
                if(not si.has_stream(i))
                    continue;
                auto istream = si.get_stream(i);
                // No cross-stream synchronization needed on the same stream
                if(stream == istream)
                    continue;
                // Create a new event if it hasn't been recorded
                if(not contains(ins2wait, i))
                {
                    ins2wait[i] = wait_id;
                    model.record(p, i, wait_id);
                    wait_id++;
                }
                auto w = ins2wait.at(i);
                // If we already waited for the event on this stream then dont
                // insert another wait event
                if(not contains(waited_for[stream], w))
                    model.wait(p, ins, w);
                // Store the event as waited
                waited_for[stream].insert(w);
                // Store all wait events that have been waited on prior to the recorded instruction
                waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end());
            }
        }
        // Store wait events that have already been waited on
        if(si.is_split_point(ins, stream))
        {
            ins2waited[ins] = waited_for[stream];
        }
    }
    // Add memory conflicts: an identity instruction after each conflict
    // key takes the conflicting instructions as extra arguments —
    // presumably to extend their liveness so later passes will not
    // overlap their memory (TODO confirm against the memory coloring pass)
    auto conflict_table = si.get_conflicts(p);
    for(auto&& ip : conflict_table)
    {
        if(ip.second.empty())
            continue;
        std::vector<instruction_ref> args;
        args.push_back(ip.first);
        args.insert(args.end(), ip.second.begin(), ip.second.end());
        p.insert_instruction(std::next(ip.first), op::identity{}, args);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -11,6 +11,7 @@ struct context
{
void finish() const {}
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -64,6 +64,7 @@ add_library(migraphx_gpu
pad.cpp
gather.cpp
lrn.cpp
schedule_model.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu)
......
......@@ -21,19 +21,21 @@ argument miopen_convolution::compute(context& ctx,
float alpha = 1;
float beta = 0;
miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Running convolution failed");
return args[3];
}
......@@ -89,8 +91,11 @@ void miopen_convolution::finalize(context& ctx,
{
if(handle == ctx.get_stream().get_miopen())
return;
// TODO: Check that workspace hasn't changed
compile(ctx, output_shape, std::move(inputs));
// Check that workspace hasn't changed
auto size = inputs.at(2).bytes();
auto ws = compile(ctx, output_shape, std::move(inputs));
if(ws.bytes() > size)
MIGRAPHX_THROW("Workspace has changed during finalization.");
}
} // namespace gpu
......
......@@ -10,12 +10,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void gpu_sync()
{
hipDeviceSynchronize();
hipCtxSynchronize();
}
using hip_ptr = MIGRAPHX_MANAGE_PTR(void, hipFree);
std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }
......@@ -105,6 +99,8 @@ void set_device(std::size_t id)
MIGRAPHX_THROW("Error setting device");
}
void gpu_sync() { hipDeviceSynchronize(); }
void copy_to_gpu(const argument& src, const argument& dst)
{
std::size_t src_size = src.get_shape().bytes();
......
......@@ -13,13 +13,17 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NULL_STREAM)
using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct hip_device
{
using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
hip_device() {}
hip_device() { add_stream(); }
hip_device(std::size_t id, std::size_t n) : device_id(id) { add_streams(n); }
hip_device(std::size_t id, std::size_t n) : device_id(id)
{
for(std::size_t i = 0; i < n; i++)
add_stream();
}
struct stream
{
......@@ -35,7 +39,6 @@ struct hip_device
{
hipStream_t result = nullptr;
auto status = hipStreamCreateWithFlags(&result, hipStreamNonBlocking);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to allocate stream");
return hip_stream_ptr{result};
......@@ -80,12 +83,6 @@ struct hip_device
return rbhandle.get();
}
void sync() const
{
if(s != nullptr)
hipStreamSynchronize(s.get());
}
void wait(hipEvent_t event)
{
setup();
......@@ -109,56 +106,22 @@ struct hip_device
shared<rocblas_handle_ptr> rbhandle = nullptr;
};
static hip_event_ptr create_event()
{
hipEvent_t event;
auto status = hipEventCreateWithFlags(&event, hipEventDisableTiming);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to creat event");
return hip_event_ptr{event};
}
void add_streams(std::size_t num_of_streams)
{
assert(streams.empty());
for(int i = 0; i < num_of_streams; ++i)
streams.emplace_back(device_id);
}
std::size_t nstreams() const { return streams.size(); }
void add_stream() { streams.emplace_back(device_id); }
stream& get_stream() { return streams.at(current_stream); }
void set_stream(std::size_t n) { current_stream = n; }
void create_events(std::size_t num_of_events)
{
for(int i = events.size(); i < num_of_events + 1; ++i)
events.emplace_back(create_event());
}
void record_event(std::size_t event)
{
streams.at(current_stream).record(events.at(event).get());
}
stream& get_stream(std::size_t n) { return streams.at(n); }
void wait_event(std::size_t event) { streams.at(current_stream).wait(events.at(event).get()); }
void set_stream(std::size_t n) { current_stream = n; }
void check_events(std::size_t n) const
{
if(n > events.size())
MIGRAPHX_THROW("The number of waits exceed the number of records.");
}
std::size_t nstreams() const { return streams.size(); }
void sync() const
{
for(auto&& stream : streams)
stream.sync();
}
std::size_t stream_id() const { return current_stream; }
private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
std::vector<stream> streams;
std::vector<shared<hip_event_ptr>> events;
};
struct context
......@@ -168,12 +131,6 @@ struct context
{
}
const hip_device& get_current_device() const
{
assert(current_device != nullptr);
return *current_device;
}
hip_device& get_current_device()
{
assert(current_device != nullptr);
......@@ -181,26 +138,34 @@ struct context
}
hip_device::stream& get_stream() { return get_current_device().get_stream(); }
hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
void set_stream(std::size_t n) { get_current_device().set_stream(n); }
void create_events(std::size_t num_of_events)
{
get_current_device().create_events(num_of_events);
for(std::size_t i = events.size(); i < num_of_events + 1; ++i)
events.emplace_back(create_event());
}
void check_events(std::size_t n) const { get_current_device().check_events(n); }
void record_event(std::size_t event) { get_current_device().record_event(event); }
void wait_event(std::size_t event) { get_current_device().wait_event(event); }
hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); }
std::vector<argument> literals{};
void finish() const
void finish() const { gpu_sync(); }
static hip_event_ptr create_event()
{
get_current_device().sync();
gpu_sync();
hipEvent_t event;
auto status = hipEventCreateWithFlags(&event, hipEventDisableTiming);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to create event");
return hip_event_ptr{event};
}
private:
// TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_RTGLIB_FIND_CONCUR_GPU_HPP
#define MIGRAPHX_GUARD_RTGLIB_FIND_CONCUR_GPU_HPP
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dom_info.hpp>
#include <migraphx/common_header.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU hook for discovering instructions that can execute concurrently.
struct find_concur_gpu
{
    // Fills concur_instrs with, per merge instruction, lists of
    // potentially concurrent instructions, and instr2_points with an
    // ordering point per instruction. Delegates to dom_info's dominator
    // analysis (compute_dom(true) — presumably the post-dominator
    // direction; confirm against dom_info).
    void get_concur(program* p,
                    int num_of_streams,
                    std::unordered_map<const instruction*,
                                       std::vector<std::vector<const instruction*>>>& concur_instrs,
                    std::unordered_map<const instruction*, int>& instr2_points) const
    {
        dom_info info(p);
        info.compute_dom(true);
        info.propagate_splits(num_of_streams, concur_instrs, instr2_points);
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INSERT_INSTRUCTION_GPU_HPP
#define MIGRAPHX_GUARD_RTGLIB_INSERT_INSTRUCTION_GPU_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/event.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU hook for inserting event/stream control instructions into a program.
struct insert_instruction_gpu
{
    void insert_create_events(program* p, instruction_ref ins, int num_of_events)
    {
        // NOTE(review): intentionally a no-op — the insertion is commented
        // out; events appear to be created elsewhere. Confirm before use.
        // p->insert_instruction(ins, create_events{num_of_events});
    }
    // Insert a record of the given event before ins
    void insert_record_event(program* p, instruction_ref ins, int event)
    {
        p->insert_instruction(ins, record_event{event});
    }
    // Insert a wait on the given event before ins
    void insert_wait_event(program* p, instruction_ref ins, int event)
    {
        p->insert_instruction(ins, wait_event{event});
    }
    // Insert a stream selection before ins
    void insert_stream(program* p, instruction_ref ins, int stream)
    {
        p->insert_instruction(ins, set_stream{stream});
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_MACHINE_MODEL_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_MACHINE_MODEL_HPP
#include <string>
#include <unordered_map>
#include <migraphx/pass_config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
namespace gpu {
// Per-operator scheduling information for the GPU machine model.
struct op_info
{
    op_info()
    {
        // First in pair denotes weight. Second in pair tells
        // that the instruction is run ONLY on CPU.
        weight_map = {{"convolution", {4, 0}},
                      {"pooling", {2, 0}},
                      {"gemm", {2, 0}},
                      {"broadcast", {1, 1}},
                      {"multibroadcast", {1, 1}},
                      {"contiguous", {1, 1}},
                      {"transpose", {1, 1}},
                      {"load", {1, 1}},
                      {"@param", {1, 1}},
                      {"@literal", {1, 1}},
                      {"hip::load_literal", {1, 1}},
                      {"hip::allocate", {0, 1}},
                      {"@outline", {0, 1}},
                      {"slice", {1, 1}},
                      {"squeeze", {1, 1}},
                      {"unsqueeze", {1, 1}},
                      {"gpu::convolution", {4, 0}},
                      {"gpu::conv_bias_relu", {4, 0}},
                      {"gpu::pooling", {2, 0}},
                      {"gpu::gemm", {2, 0}},
                      {"gpu::concat", {1, 0}},
                      {"hip::add_relu", {2, 0}}};
    }
    // Returns {weight, cpu_only} for op; unknown ops default to weight 1
    // and are CPU-only exactly when they are context free.
    std::pair<int, int> operator()(const operation& op)
    {
        // Single lookup instead of find() followed by operator[]
        auto it = weight_map.find(op.name());
        if(it != weight_map.end())
        {
            return it->second;
        }
        return std::make_pair(1, is_context_free(op) ? 1 : 0);
    }
    std::unordered_map<std::string, std::pair<int, int>> weight_map;
};
} // namespace gpu
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_SCHEDULE_MODEL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_SCHEDULE_MODEL_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct operation;
namespace gpu {
// Machine model the generic scheduler pass uses to target the GPU:
// how many streams exist and how to insert stream/record/wait ops.
struct schedule_model
{
    // Number of concurrent streams available for scheduling
    std::size_t streams = 0;
    // Returns the concurrency level (the number of streams)
    std::size_t concurrency() const;
    // Schedule ins to run on stream n
    void sched(program& p, instruction_ref ins, std::size_t n) const;
    // Insert a wait on event wait_id before ins
    void wait(program& p, instruction_ref ins, std::size_t wait_id) const;
    // Insert a record of event wait_id after ins
    void record(program& p, instruction_ref ins, std::size_t wait_id) const;
    // Relative cost estimate used to weight op during partitioning
    std::size_t weight(const operation& op) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EVENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_EVENT_HPP
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Operation that pre-allocates num_of_events hip events in the context.
struct create_events
{
    int num_of_events = 0;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        // NOTE(review): the reflected field name is "event" though the
        // member is num_of_events — confirm this matches serialized forms
        return pack(f(self.num_of_events, "event"));
    }
    std::string name() const { return "gpu::create_events"; }
    // Produces no output value
    shape compute_shape(const std::vector<shape>&) const { return {}; }
    argument compute(context& ctx, const shape&, const std::vector<argument>&) const
    {
        ctx.create_events(num_of_events);
        return {};
    }
};
struct record_event
{
int event = -1;
std::size_t event = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
......@@ -42,20 +21,19 @@ struct record_event
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
ctx.record_event(event);
ctx.get_stream().record(ctx.get_event(event));
return {};
}
void finalize(context& ctx, const shape&, std::vector<shape>)
void finalize(context& ctx, const shape&, const std::vector<shape>&)
{
assert(event >= 0);
ctx.create_events(event);
}
};
struct wait_event
{
int event = -1;
std::size_t event = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
......@@ -66,20 +44,14 @@ struct wait_event
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
ctx.wait_event(event);
ctx.get_stream().wait(ctx.get_event(event));
return {};
}
void finalize(context& ctx, const shape&, std::vector<shape>)
{
assert(event >= 0);
ctx.check_events(event);
}
};
struct set_stream
{
int stream = -1;
std::size_t stream = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
......@@ -90,14 +62,66 @@ struct set_stream
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
assert(stream >= 0);
ctx.set_stream(stream);
return {};
}
void finalize(context& ctx, const shape&, const std::vector<shape>&) { ctx.set_stream(stream); }
};
// The configured number of streams is the available concurrency
std::size_t schedule_model::concurrency() const { return streams; }
// Insert a set_stream instruction before ins, unless the most recent
// set_stream in the program already selects stream n.
void schedule_model::sched(program& p, instruction_ref ins, std::size_t n) const
{
    // Scan backwards from ins for the last stream assignment
    auto last_stream = std::find_if(std::make_reverse_iterator(ins),
                                    std::make_reverse_iterator(p.begin()),
                                    [&](auto&& i) { return i.name() == "gpu::set_stream"; });
    if(last_stream != std::make_reverse_iterator(p.begin()))
    {
        auto&& op = any_cast<set_stream>(last_stream->get_operator());
        // If the same stream was set earlier then skip
        if(op.stream == n)
            return;
    }
    p.insert_instruction(ins, set_stream{n});
}
// Insert an instruction before ins that waits on event wait_id
void schedule_model::wait(program& p, instruction_ref ins, std::size_t wait_id) const
{
    p.insert_instruction(ins, wait_event{wait_id});
}
// Insert an instruction after ins that records event wait_id
void schedule_model::record(program& p, instruction_ref ins, std::size_t wait_id) const
{
    p.insert_instruction(std::next(ins), record_event{wait_id});
}
// Relative execution cost per GPU operator; operators missing from this
// table default to a weight of 1 (see schedule_model::weight).
static std::unordered_map<std::string, std::size_t> create_weight_map()
{
    std::unordered_map<std::string, std::size_t> m;
    m.emplace("hip::load_literal", 0);
    m.emplace("hip::allocate", 0);
    m.emplace("gpu::convolution", 4);
    m.emplace("gpu::conv_bias_relu", 4);
    m.emplace("gpu::pooling", 2);
    m.emplace("gpu::gemm", 2);
    m.emplace("gpu::concat", 1);
    m.emplace("hip::add_relu", 2);
    return m;
}
// Lazily-initialized shared weight table (thread-safe magic static)
static const std::unordered_map<std::string, std::size_t>& weight_map()
{
    static const std::unordered_map<std::string, std::size_t> table = create_weight_map();
    return table;
}
// Scheduling weight of op; operators absent from the table default to 1.
std::size_t schedule_model::weight(const operation& op) const
{
    // Single lookup instead of count() followed by at()
    const auto& m = weight_map();
    auto it = m.find(op.name());
    if(it == m.end())
    {
        return 1;
    }
    return it->second;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment