Commit e747cf2e authored by mei-ye's avatar mei-ye
Browse files

coloring

parent ad17c504
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <set> #include <set>
#include <list> #include <list>
#include <vector> #include <vector>
#include <queue>
#define DEBUG_OPT #define DEBUG_OPT
......
...@@ -2,23 +2,69 @@ ...@@ -2,23 +2,69 @@
namespace migraph { namespace migraph {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{
build();
if (num_of_lives != 0) {
DEBUG(dump("---Before memory coloring---"));
DEBUG(dump());
// Coloring
while (!alloc_queue.empty()) {
T_live_interval* interval = alloc_queue.top();
allocate(interval);
alloc_queue.pop();
}
for (int i = 0; i < num_of_lives; ++i)
free(live_intervals[i]);
}
}
bool memory_coloring_impl::allocate(T_live_interval* interval)
{
shape s = interval->result;
int size = s.bytes();
std::size_t element_size = size / s.elements();
T_live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<T_live_range*, std::vector<T_live_range*>, ordering> conflict_queue;
if (conflict_table.find(vn) != conflict_table.end()) {
std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) {
T_live_range* range = live_ranges[*iter];
if (range->offset != -1)
conflict_queue.push(range);
}
}
int offset = 0;
while (!conflict_queue.empty()) {
T_live_range* range = conflict_queue.top();
int cur_offset = range->offset;
if ((cur_offset > offset) && (cur_offset - offset) >= size) {
break;
}
offset = cur_offset + range->size;
if ((offset % element_size) != 0)
offset += (element_size - (offset % element_size));
conflict_queue.pop();
}
segment.offset = offset;
return true;
}
void memory_coloring_impl::build()
{ {
int num_of_instrs = p_program->get_size(); int num_of_instrs = p_program->get_size();
if (num_of_instrs == 0) if (num_of_instrs == 0)
return; return;
DEBUG(dump("---Before memory coloring---"));
int cur_points = num_of_instrs * 2; int cur_points = num_of_instrs * 2;
instruction_ref iter = std::prev(p_program->end()); instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin(); instruction_ref begin = p_program->begin();
std::vector<T_live_interval*> live_intervals;
std::vector<instruction_ref> dead_instrs; std::vector<instruction_ref> dead_instrs;
std::list<T_live_interval*> active_queue;
int num_of_lives = 0;
std::unordered_map<const instruction*, T_live_interval*> instr2Live; std::unordered_map<const instruction*, T_live_interval*> instr2Live;
std::set<int> live_set; std::set<int> live_set;
T_live_interval* next_def = nullptr; T_live_interval* next_def = nullptr;
live_intervals.reserve(num_of_instrs); // Build live intervals.
do { do {
const instruction* p_iter = &(*iter); const instruction* p_iter = &(*iter);
T_live_interval* def_interval = nullptr; T_live_interval* def_interval = nullptr;
...@@ -27,19 +73,18 @@ void memory_coloring_impl::run() ...@@ -27,19 +73,18 @@ void memory_coloring_impl::run()
def_interval = instr2Live[p_iter]; def_interval = instr2Live[p_iter];
bool isLit = isLiteral(iter); bool isLit = isLiteral(iter);
if (isAllocate(iter) || isLit) { if (isAllocate(iter) || isLit) {
T_live_range& range = def_interval->segments.front(); T_live_range& range = def_interval->segment;
range.begin = cur_points;
def_interval->result = iter->result; def_interval->result = iter->result;
def_interval->isLiteral = isLit; def_interval->isLiteral = isLit;
def_interval->next_enqueue_def = cur_points; alloc_queue.push(def_interval);
active_queue.push_front(def_interval); range.begin = cur_points;
range.size = (iter->result).bytes();
next_def = def_interval; next_def = def_interval;
live_set.erase(def_interval->id); live_set.erase(range.vn);
} }
} else if (!isParam(iter) && !isOutline(iter) && !isCheckContext(iter)) { } else if (!isParam(iter) && !isOutline(iter) && !isCheckContext(iter)) {
isDead = true; isDead = true;
} }
if (!iter->arguments.empty()) { if (!iter->arguments.empty()) {
for (auto&& arg : iter->arguments) { for (auto&& arg : iter->arguments) {
if (isParam(arg) || isOutline(arg)) { if (isParam(arg) || isOutline(arg)) {
...@@ -58,15 +103,18 @@ void memory_coloring_impl::run() ...@@ -58,15 +103,18 @@ void memory_coloring_impl::run()
int id = num_of_lives++; int id = num_of_lives++;
T_live_interval* interval = new live_interval(); T_live_interval* interval = new live_interval();
interval->id = id; interval->id = id;
interval->segments.push_back(T_live_range{-1, cur_points, -1}); interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->addUse(cur_points); interval->addUse(cur_points);
instr2Live[p_arg] = interval; instr2Live[p_arg] = interval;
live_set.insert(id); addConflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_intervals[id] = interval;
live_ranges[max_value_number] = &(interval->segment);
// Keep track of live intervals that are inactive when // Keep track of live intervals that are inactive when
// next_def is enqueued. // next_def is enqueued.
if (next_def != nullptr) if (next_def != nullptr)
next_def->inactive_afters.push_back(interval); next_def->inactive_afters.push_back(interval);
live_intervals[id] = interval;
} else { } else {
T_live_interval* interval = instr2Live[p_arg]; T_live_interval* interval = instr2Live[p_arg];
interval->addUse(cur_points); interval->addUse(cur_points);
...@@ -79,12 +127,8 @@ void memory_coloring_impl::run() ...@@ -79,12 +127,8 @@ void memory_coloring_impl::run()
cur_points -= 2; cur_points -= 2;
iter = std::prev(iter); iter = std::prev(iter);
} while (iter != begin); } while (iter != begin);
DEBUG(dump(live_intervals, num_of_lives));
for (int i = 0; i < num_of_lives; ++i)
free(live_intervals[i]);
} }
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void memory_coloring_impl::dump(std::string str) void memory_coloring_impl::dump(std::string str)
{ {
...@@ -92,7 +136,7 @@ void memory_coloring_impl::dump(std::string str) ...@@ -92,7 +136,7 @@ void memory_coloring_impl::dump(std::string str)
std::cout << *p_program << std::endl; std::cout << *p_program << std::endl;
} }
void memory_coloring_impl::dump(std::vector<T_live_interval*>& live_intervals, int num_of_lives) void memory_coloring_impl::dump()
{ {
if (num_of_lives > 0) { if (num_of_lives > 0) {
std::cout << "---live intervals ---" << std::endl; std::cout << "---live intervals ---" << std::endl;
...@@ -100,24 +144,31 @@ void memory_coloring_impl::dump(std::vector<T_live_interval*>& live_intervals, i ...@@ -100,24 +144,31 @@ void memory_coloring_impl::dump(std::vector<T_live_interval*>& live_intervals, i
T_live_interval* interval = live_intervals[i]; T_live_interval* interval = live_intervals[i];
interval->dump(); interval->dump();
} }
std::cout << "conflict table:" << std::endl;
for (int i = 0; i <= max_value_number; ++i) {
std::cout << " segment:" << i;
std::cout << " =>";
std::set<int>& table = conflict_table[i];
for (auto iter = table.begin(), end = table.end(); iter != end; ++iter) {
std::cout << (*iter) << ",";
}
}
std::cout << std::endl;
} }
} }
#define GET_INS_ENUM(x) (((x) >> 1) - 1) #define GET_INS_ENUM(x) (((x) >> 1) - 1)
void live_interval::dump() void live_interval::dump()
{ {
std::cout << "id:" << id; std::cout << "id:" << id;
for (auto iter = segments.begin(), end = segments.end(); iter != end; ++iter) { std::cout << " segment:" << segment.vn;
T_live_range& range = *iter; std::cout << " [" << GET_INS_ENUM(segment.begin) << ", " << GET_INS_ENUM(segment.end) << "]";
std::cout << " [" << GET_INS_ENUM(range.begin) << ", " << GET_INS_ENUM(range.end) << "]";
}
std::cout << " uses:"; std::cout << " uses:";
for (auto iter = use_points.begin(), end = use_points.end(); iter != end; ++iter) { for (auto iter = use_points.begin(), end = use_points.end(); iter != end; ++iter) {
int& use = *iter; int& use = *iter;
std::cout << " " << GET_INS_ENUM(use) << ","; std::cout << " " << GET_INS_ENUM(use) << ",";
} }
if (!inactive_afters.empty()) { if (!inactive_afters.empty()) {
std::cout << " inactivate:"; std::cout << " inactivate:";
for (auto iter = inactive_afters.begin(), end = inactive_afters.end(); iter != end; ++iter) { for (auto iter = inactive_afters.begin(), end = inactive_afters.end(); iter != end; ++iter) {
...@@ -127,6 +178,7 @@ void live_interval::dump() ...@@ -127,6 +178,7 @@ void live_interval::dump()
} }
if (isLiteral) if (isLiteral)
std::cout << " literal"; std::cout << " literal";
std::cout << " " << result;
std::cout << std::endl; std::cout << std::endl;
} }
#endif #endif
......
...@@ -5,44 +5,57 @@ ...@@ -5,44 +5,57 @@
namespace migraph { namespace migraph {
typedef struct live_range { typedef struct live_range {
explicit live_range(int b, int e, int o ) : begin(b), end(e), offset(o) {}; int begin; // begin point in the instruction stream.
int begin; int end; // end point in the instruction stream.
int end; int offset; // offset to base pointer of allocated memory trunk.
int offset; int vn; // value number that identifies this live_range.
int size; // size of required memory in bytes
} T_live_range; } T_live_range;
typedef struct live_interval { typedef struct live_interval {
explicit live_interval() { init(); } explicit live_interval() { init(); }
void addUse(int use) { use_points.push_front(use); }
void init() { void init() {
id = -1; isLiteral = false; id = -1; isLiteral = false;
segment = { -1, -1, -1, -1, 0};
} }
std::list <T_live_range> segments; void addUse(int use) { use_points.push_front(use); }
int get_begin() const { return segment.begin; }
int get_end() const { return segment.end; }
#ifdef DEBUG_OPT
void dump();
#endif
T_live_range segment;
int id; int id;
std::list<int> use_points; std::list<int> use_points;
// Live intervals that are inactive when this live interval is enqueued. // Live intervals that are inactive when this live interval is enqueued.
// can be used for live interval collapsing.
std::list<struct live_interval*> inactive_afters; std::list<struct live_interval*> inactive_afters;
// Next enqueue point for this live interval. It is not always
// equal to the begin if this live interval is rematerialized.
int next_enqueue_def;
shape result; shape result;
bool isLiteral; bool isLiteral;
#ifdef DEBUG_OPT
void dump();
#endif
} T_live_interval; } T_live_interval;
typedef struct occupant_range {
explicit occupant_range(int b, int e, T_live_interval* in)
: begin(b), end(e), interval(in) {};
int begin;
int end;
T_live_interval* interval;
} T_occupant_range;
struct memory_coloring_impl { struct memory_coloring_impl {
explicit memory_coloring_impl(program *p) : p_program(p) {} explicit memory_coloring_impl(program *p) : p_program(p)
{
live_intervals.clear();
live_ranges.clear();
conflict_table.clear();
num_of_lives = 0;
max_value_number = -1;
}
bool allocate(T_live_interval*);
void addConflicts(std::set<int>& live_set, int val)
{
for (auto iter = live_set.begin(), end = live_set.end(); iter != end; ++ iter) {
conflict_table[*iter].insert(val);
conflict_table[val].insert(*iter);
}
}
void build();
void run(); void run();
private: private:
bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; } bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; }
...@@ -54,13 +67,40 @@ struct memory_coloring_impl { ...@@ -54,13 +67,40 @@ struct memory_coloring_impl {
bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; } bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; } bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; }
bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; } bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; }
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void dump(std::string); void dump(std::string);
void dump(std::vector<T_live_interval*>&, int); void dump();
#endif #endif
struct ordering {
bool operator() (const T_live_interval* I1, const T_live_interval* I2) const
{
int len1 = I1->get_end() - I1->get_begin();
int len2 = I2->get_end() - I2->get_begin();
if (len1 < len2)
return true;
else if (I1->result.bytes() < I2->result.bytes())
return true;
else
return I1->id > I2->id;
}
bool operator() (const T_live_range* I1, const T_live_range* I2) const
{
return (I1->offset > I2->offset);
}
};
program* p_program; program* p_program;
// Map live interval Id to live interval.
std::unordered_map<int, T_live_interval*> live_intervals;
// Map live range value number to live range.
std::unordered_map<int, T_live_range*> live_ranges;
// Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table;
// Priority queue for coloring.
std::priority_queue<T_live_interval*, std::vector<T_live_interval*>, ordering> alloc_queue;
int num_of_lives;
int max_value_number;
}; };
} // namespace migraph } // namespace migraph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment