Commit 7b39fb38 authored by mei-ye
Browse files

staging

parent d877a3fb
...@@ -12,8 +12,6 @@ struct memory_coloring ...@@ -12,8 +12,6 @@ struct memory_coloring
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -534,15 +534,13 @@ struct div : binary ...@@ -534,15 +534,13 @@ struct div : binary
std::string name() const { return "div"; } std::string name() const { return "div"; }
}; };
struct get_mem_ptr struct get_mem_ptr
{ {
std::string name() const { return "get_mem_ptr:" + std::to_string(offset); } std::string name() const { return "get_mem_ptr:" + std::to_string(offset); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return inputs.at(1); return {std::move(output_shape), args.at(0).data() + offset};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
return {output_shape, args.at(0).data() + offset};
} }
std::size_t offset = 0; std::size_t offset = 0;
}; };
...@@ -550,14 +548,12 @@ struct get_mem_ptr ...@@ -550,14 +548,12 @@ struct get_mem_ptr
struct write_literal struct write_literal
{ {
std::string name() const { return "write_literal"; } std::string name() const { return "write_literal"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
argument compute(context&, shape, std::vector<argument>) const
{ {
return inputs.at(2);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
assert(false); assert(false);
} }
}; };
struct outline struct outline
{ {
...@@ -573,7 +569,6 @@ struct outline ...@@ -573,7 +569,6 @@ struct outline
return {s, nullptr}; return {s, nullptr};
} }
}; };
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -93,7 +93,7 @@ struct program ...@@ -93,7 +93,7 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
int get_size() const; int get_size() const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
...@@ -102,7 +102,6 @@ struct program ...@@ -102,7 +102,6 @@ struct program
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/manage_ptr.hpp>
#include <set> #include <set>
#include <list> #include <list>
...@@ -19,4 +20,4 @@ ...@@ -19,4 +20,4 @@
#define DEBUG(s) #define DEBUG(s)
#endif // DEBUG_OPT #endif // DEBUG_OPT
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
...@@ -2,10 +2,9 @@ ...@@ -2,10 +2,9 @@
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraph {
// Pass entry point: hands the program to a memory_coloring_impl and runs it.
void memory_coloring::apply(program& p) const
{
    memory_coloring_impl coloring(&p);
    coloring.run();
}
} // namespace migraph } // namespace migraph
...@@ -4,50 +4,55 @@ namespace migraph { ...@@ -4,50 +4,55 @@ namespace migraph {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{ {
build(); build();
if (num_of_lives != 0) { if(num_of_lives != 0)
{
DEBUG(dump("---Before memory coloring---")); DEBUG(dump("---Before memory coloring---"));
DEBUG(dump_program()); DEBUG(dump_program());
DEBUG(dump_intervals()); DEBUG(dump_intervals());
// Coloring // Coloring
while (!alloc_queue.empty()) { while(!alloc_queue.empty())
T_live_interval* interval = alloc_queue.top(); {
interval_ptr interval = alloc_queue.top();
allocate(interval); allocate(interval);
alloc_queue.pop(); alloc_queue.pop();
} }
rewrite(); rewrite();
DEBUG(verify()); DEBUG(verify());
for (int i = 0; i < num_of_lives; ++i) {
free(live_intervals[i]);
}
} }
} }
bool memory_coloring_impl::allocate(T_live_interval* interval) bool memory_coloring_impl::allocate(interval_ptr interval)
{ {
shape s = interval->result; shape s = interval->result;
std::size_t size = s.bytes(); std::size_t size = s.bytes();
if (size == 0) if(size == 0)
return false; return false;
std::size_t element_size = size / s.elements(); std::size_t element_size = size / s.elements();
T_live_range& segment = interval->segment; T_live_range& segment = interval->segment;
int vn = segment.vn; int vn = segment.vn;
std::priority_queue<T_live_range*, std::vector<T_live_range*>, ordering> conflict_queue; std::priority_queue<T_live_range*, std::vector<T_live_range*>, ordering> conflict_queue;
std::unordered_map<long long, T_live_range*> offset2Live; std::unordered_map<long long, T_live_range*> offset2Live;
offset2Live.clear(); offset2Live.clear();
if (conflict_table.find(vn) != conflict_table.end()) { if(conflict_table.find(vn) != conflict_table.end())
{
std::set<int>& vn_set = conflict_table[vn]; std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) { for(auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter)
{
T_live_range* range = live_ranges[*iter]; T_live_range* range = live_ranges[*iter];
long long offset = range->offset; long long offset = range->offset;
if (offset != InvalidOffset) { if(offset != InvalidOffset)
{
conflict_queue.push(range); conflict_queue.push(range);
if (offset2Live.find(offset) == offset2Live.end()) { if(offset2Live.find(offset) == offset2Live.end())
{
offset2Live[offset] = range; offset2Live[offset] = range;
} else { }
else
{
T_live_range* prev = offset2Live[offset]; T_live_range* prev = offset2Live[offset];
assert(prev->offset == offset); assert(prev->offset == offset);
if (prev->size < range->size) if(prev->size < range->size)
offset2Live[offset] = range; offset2Live[offset] = range;
} }
} }
...@@ -55,15 +60,18 @@ bool memory_coloring_impl::allocate(T_live_interval* interval) ...@@ -55,15 +60,18 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
} }
long long offset = 0; long long offset = 0;
while (!conflict_queue.empty()) { while(!conflict_queue.empty())
T_live_range* range = conflict_queue.top(); {
T_live_range* range = conflict_queue.top();
long long cur_offset = range->offset; long long cur_offset = range->offset;
if (offset2Live[cur_offset] == range) { if(offset2Live[cur_offset] == range)
if ((cur_offset > offset) && (cur_offset - offset) >= size) { {
if((cur_offset > offset) && (cur_offset - offset) >= size)
{
break; break;
} }
offset = cur_offset + range->size; offset = cur_offset + range->size;
if ((offset % element_size) != 0) if((offset % element_size) != 0)
offset += (element_size - (offset % element_size)); offset += (element_size - (offset % element_size));
} }
conflict_queue.pop(); conflict_queue.pop();
...@@ -77,98 +85,118 @@ bool memory_coloring_impl::allocate(T_live_interval* interval) ...@@ -77,98 +85,118 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
void memory_coloring_impl::build() void memory_coloring_impl::build()
{ {
int num_of_instrs = p_program->get_size(); int num_of_instrs = p_program->get_size();
if (num_of_instrs == 0) if(num_of_instrs == 0)
return; return;
int cur_points = num_of_instrs * 2; int cur_points = num_of_instrs * 2;
instruction_ref iter = std::prev(p_program->end()); instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin(); instruction_ref begin = p_program->begin();
std::vector<instruction_ref> dead_instrs; std::vector<instruction_ref> dead_instrs;
std::set<int> live_set; std::set<int> live_set;
// Build live intervals. // Build live intervals.
do { do
const instruction* p_iter = &(*iter); {
T_live_interval* def_interval = nullptr; const instruction* p_iter = &(*iter);
bool isDead = false; interval_ptr def_interval = nullptr;
if (instr2Live.find(p_iter) != instr2Live.end()) { bool isDead = false;
def_interval = instr2Live[p_iter]; if(instr2Live.find(p_iter) != instr2Live.end())
bool isLit = isLiteral(iter); {
if (isAllocate(iter) || isLit) { def_interval = std::move(instr2Live[p_iter]);
T_live_range& range = def_interval->segment; bool isLit = isLiteral(iter);
def_interval->result = iter->result; if(isAllocate(iter) || isLit)
{
T_live_range& range = def_interval->segment;
def_interval->result = iter->result;
def_interval->isLiteral = isLit; def_interval->isLiteral = isLit;
alloc_queue.push(def_interval); alloc_queue.push(std::move(def_interval));
range.begin = cur_points; range.begin = cur_points;
range.size = (iter->result).bytes(); range.size = (iter->result).bytes();
live_set.erase(range.vn); live_set.erase(range.vn);
} }
} else if (!isParam(iter) && !isOutline(iter) && !isCheckContext(iter)) { }
else if(!isParam(iter) && !isOutline(iter) && !isCheckContext(iter))
{
isDead = true; isDead = true;
} }
int tieNdx = getInputTieNdx(iter); int tieNdx = getInputTieNdx(iter);
if (!iter->arguments.empty()) { if(!iter->arguments.empty())
{
int cnt = -1; int cnt = -1;
for (auto&& arg : iter->arguments) { for(auto&& arg : iter->arguments)
{
cnt++; cnt++;
if (isParam(arg) || isOutline(arg)) { if(isParam(arg) || isOutline(arg))
if (isOutputParam(arg)) {
if(isOutputParam(arg))
isDead = false; isDead = false;
continue; continue;
} }
const instruction* p_arg = &(*arg); const instruction* p_arg = &(*arg);
if (cnt == tieNdx) { if(cnt == tieNdx)
{
// input memory is used as this instruction's output. // input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals. // def is considered as use. Coalesce the live intervals.
def_interval->addUse(cur_points); def_interval->addUse(cur_points);
instr2Live[p_arg] = def_interval; instr2Live[p_arg] = def_interval;
} else if (instr2Live.find(p_arg) == instr2Live.end()) { }
else if(instr2Live.find(p_arg) == instr2Live.end())
{
// First time see a use, create a live interval. // First time see a use, create a live interval.
int id = num_of_lives++; int id = num_of_lives++;
T_live_interval* interval = new live_interval(); interval_ptr interval(new live_interval());
interval->id = id; interval->id = id;
interval->segment.end = cur_points; interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number; interval->segment.vn = ++max_value_number;
interval->addUse(cur_points); interval->addUse(cur_points);
instr2Live[p_arg] = interval; instr2Live[p_arg] = interval;
addConflicts(live_set, max_value_number); addConflicts(live_set, max_value_number);
live_set.insert(max_value_number); live_set.insert(max_value_number);
live_intervals[id] = interval; live_intervals[id] = std::move(interval);
live_ranges[max_value_number] = &(interval->segment); live_ranges[max_value_number] = &(interval->segment);
} else { }
T_live_interval* interval = instr2Live[p_arg]; else
{
interval_ptr interval = instr2Live[p_arg];
interval->addUse(cur_points); interval->addUse(cur_points);
DEBUG(assert(live_set.find(interval->id) != live_set.end())); DEBUG(assert(live_set.find(interval->id) != live_set.end()));
} }
} }
} }
if (isDead) if(isDead)
dead_instrs.push_back(iter); dead_instrs.push_back(iter);
cur_points -= 2; cur_points -= 2;
iter = std::prev(iter); iter = std::prev(iter);
} while (iter != begin); } while(iter != begin);
} }
void memory_coloring_impl::rewrite() void memory_coloring_impl::rewrite()
{ {
instruction_ref end = p_program->end(); instruction_ref end = p_program->end();
instruction_ref scratch_param = end; instruction_ref scratch_param = end;
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
dims.push_back(required_bytes/sizeof(float)); dims.push_back(required_bytes / sizeof(float));
shape s = {shape::float_type, dims}; shape s = {shape::float_type, dims};
scratch_param = p_program->add_parameter("scratch", s); scratch_param = p_program->add_parameter("scratch", s);
for (auto ins : iterator_for(*p_program)) { for(auto ins : iterator_for(*p_program))
{
const instruction* p_iter = &(*ins); const instruction* p_iter = &(*ins);
if (instr2Live.find(p_iter) != instr2Live.end()) { if(instr2Live.find(p_iter) != instr2Live.end())
T_live_interval* interval = instr2Live[p_iter]; {
if (interval->get_offset() == InvalidOffset) { interval_ptr interval = instr2Live[p_iter];
if(interval->get_offset() == InvalidOffset)
{
DEBUG(assert((interval->get_begin() == InvalidOffset) || DEBUG(assert((interval->get_begin() == InvalidOffset) ||
interval->result.bytes() == 0)); interval->result.bytes() == 0));
continue; continue;
} }
std::size_t offset = interval->get_offset(); std::size_t offset = interval->get_offset();
if (isAllocate(ins)) { if(isAllocate(ins))
p_program->replace_instruction(ins, get_mem_ptr{offset}, scratch_param, ins->arguments.at(0)); {
} else if (isLiteral(ins)) { p_program->replace_instruction(
auto pre = p_program->add_literal(ins->lit); ins, get_mem_ptr{offset}, scratch_param, ins->arguments.at(0));
}
else if(isLiteral(ins))
{
auto pre = p_program->add_literal(ins->lit);
auto index = p_program->add_literal(offset); auto index = p_program->add_literal(offset);
p_program->replace_instruction(ins, write_literal{}, scratch_param, index, pre); p_program->replace_instruction(ins, write_literal{}, scratch_param, index, pre);
} }
...@@ -177,33 +205,30 @@ void memory_coloring_impl::rewrite() ...@@ -177,33 +205,30 @@ void memory_coloring_impl::rewrite()
DEBUG(dump("---After rewrite---")); DEBUG(dump("---After rewrite---"));
DEBUG(dump_program()); DEBUG(dump_program());
} }
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
// Debug helper: writes `str` to std::cout, followed by a newline and flush.
void memory_coloring_impl::dump(std::string str)
{
    std::cout << str;
    std::cout << std::endl;
}
void memory_coloring_impl::dump_program() void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; }
{
std::cout << *p_program << std::endl;
}
void memory_coloring_impl::dump_intervals() void memory_coloring_impl::dump_intervals()
{ {
if (num_of_lives > 0) { if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl; std::cout << "---live intervals ---" << std::endl;
for (int i = 0; i < num_of_lives; ++i) { for(int i = 0; i < num_of_lives; ++i)
T_live_interval* interval = live_intervals[i]; {
interval_ptr interval = live_intervals[i];
interval->dump(); interval->dump();
} }
std::cout << "---conflict table---" << std::endl; std::cout << "---conflict table---" << std::endl;
for (int i = 0; i <= max_value_number; ++i) { for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i; std::cout << " segment:" << i;
std::cout << " =>"; std::cout << " =>";
std::set<int>& table = conflict_table[i]; std::set<int>& table = conflict_table[i];
for (auto iter = table.begin(), end = table.end(); iter != end; ++iter) { for(auto iter = table.begin(), end = table.end(); iter != end; ++iter)
{
std::cout << (*iter) << ","; std::cout << (*iter) << ",";
} }
} }
...@@ -213,20 +238,24 @@ void memory_coloring_impl::dump_intervals() ...@@ -213,20 +238,24 @@ void memory_coloring_impl::dump_intervals()
void memory_coloring_impl::verify() void memory_coloring_impl::verify()
{ {
if (num_of_lives > 0) { if(num_of_lives > 0)
for (int i = 0; i < num_of_lives; ++i) { {
T_live_interval* interval = live_intervals[i]; for(int i = 0; i < num_of_lives; ++i)
T_live_range& segment = interval->segment; {
if (segment.offset == InvalidOffset) interval_ptr interval = live_intervals[i];
T_live_range& segment = interval->segment;
if(segment.offset == InvalidOffset)
continue; continue;
int vn = segment.vn; int vn = segment.vn;
if (conflict_table.find(vn) != conflict_table.end()) { if(conflict_table.find(vn) != conflict_table.end())
{
std::set<int>& vn_set = conflict_table[vn]; std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) { for(auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter)
{
T_live_range* range = live_ranges[*iter]; T_live_range* range = live_ranges[*iter];
if (range->offset == InvalidOffset) if(range->offset == InvalidOffset)
continue; continue;
if (!isDisjoin(*range, segment)) if(!isDisjoin(*range, segment))
assert(false); assert(false);
} }
} }
...@@ -235,34 +264,35 @@ void memory_coloring_impl::verify() ...@@ -235,34 +264,35 @@ void memory_coloring_impl::verify()
} }
#define GET_INS_ENUM(x) (((x) >> 1) - 1) #define GET_INS_ENUM(x) (((x) >> 1) - 1)
void live_range::dump() void live_range::dump()
{ {
std::cout << " segment:" << vn; std::cout << " segment:" << vn;
std::cout << " [" << GET_INS_ENUM(begin) << ", " << GET_INS_ENUM(end) << "]"; std::cout << " [" << GET_INS_ENUM(begin) << ", " << GET_INS_ENUM(end) << "]";
if (offset != InvalidOffset) { if(offset != InvalidOffset)
{
std::cout << " mem:"; std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size - 1 << "]"; std::cout << " [" << offset << "," << offset + size - 1 << "]";
} }
std::cout << std::endl; std::cout << std::endl;
} }
void live_interval::dump() void live_interval::dump()
{ {
std::cout << "id:" << id; std::cout << "id:" << id;
segment.dump(); segment.dump();
std::cout << " uses:"; std::cout << " uses:";
for (auto iter = use_points.begin(), end = use_points.end(); iter != end; ++iter) { for(auto iter = use_points.begin(), end = use_points.end(); iter != end; ++iter)
{
int& use = *iter; int& use = *iter;
std::cout << " " << GET_INS_ENUM(use) << ","; std::cout << " " << GET_INS_ENUM(use) << ",";
} }
if (isLiteral) if(isLiteral)
std::cout << " literal"; std::cout << " literal";
std::cout << " " << result; std::cout << " " << result;
std::cout << std::endl; std::cout << std::endl;
} }
#endif #endif
} // namespace migraph } // namespace migraph
...@@ -5,33 +5,38 @@ ...@@ -5,33 +5,38 @@
namespace migraph { namespace migraph {
#define InvalidOffset -1 #define InvalidOffset -1
// A live range: the span of the instruction stream over which one buffer is
// live, plus the memory slot chosen for it. The struct stays an aggregate, so
// brace initialisation in member order (begin, end, offset, vn, size) works.
struct live_range
{
    int begin;        // begin point in the instruction stream.
    int end;          // end point in the instruction stream.
    long long offset; // offset to base pointer of allocated memory trunk.
    int vn;           // value number that identifies this live_range.
    long long size;   // size of required memory in bytes
#ifdef DEBUG_OPT
    void dump();
#endif
};
// Alias kept so existing uses of T_live_range continue to compile.
using T_live_range = live_range;
typedef struct live_interval { typedef struct live_interval
explicit live_interval() { init(); } {
live_interval() { init(); }
void init() { ~live_interval() {}
id = -1; isLiteral = false; void init()
segment = { -1, -1, InvalidOffset, -1, 0}; {
id = -1;
isLiteral = false;
segment = {-1, -1, InvalidOffset, -1, 0};
} }
void addUse(int use) { use_points.push_front(use); } void addUse(int use) { use_points.push_front(use); }
int get_begin() const { return segment.begin; } int get_begin() const { return segment.begin; }
int get_end() const { return segment.end; } int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; } long long get_offset() const { return segment.offset; }
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void dump(); void dump();
#endif #endif
T_live_range segment; T_live_range segment;
int id; int id;
...@@ -41,21 +46,34 @@ typedef struct live_interval { ...@@ -41,21 +46,34 @@ typedef struct live_interval {
} T_live_interval; } T_live_interval;
struct memory_coloring_impl { // #define unique_interval_ptr std::unique_ptr<T_live_interval>
explicit memory_coloring_impl(program *p) : p_program(p) #define interval_ptr T_live_interval*
struct memory_coloring_impl
{
memory_coloring_impl(){}
memory_coloring_impl(program* p) : p_program(p)
{ {
init();
}
// Releases every live interval stored in live_intervals. The intervals are
// created with `new`, so they must be released with `delete`: `free` would skip
// their destructors and is undefined behaviour for new-allocated objects.
// Iterating the map itself avoids depending on num_of_lives and does not
// insert missing keys the way operator[] would.
~memory_coloring_impl()
{
    for(auto& entry : live_intervals)
        delete entry.second;
}
// Resets all bookkeeping: zeroes the counters (max_value_number starts at -1)
// and empties every lookup table.
void init()
{
    num_of_lives     = 0;
    max_value_number = -1;
    required_bytes   = 0;

    instr2Live.clear();
    live_intervals.clear();
    live_ranges.clear();
    conflict_table.clear();
}
bool allocate(T_live_interval*); bool allocate(interval_ptr);
void addConflicts(std::set<int>& live_set, int val) void addConflicts(std::set<int>& live_set, int val)
{ {
for (auto iter = live_set.begin(), end = live_set.end(); iter != end; ++ iter) { for(auto iter = live_set.begin(), end = live_set.end(); iter != end; ++iter)
{
conflict_table[*iter].insert(val); conflict_table[*iter].insert(val);
conflict_table[val].insert(*iter); conflict_table[val].insert(*iter);
} }
...@@ -63,78 +81,87 @@ struct memory_coloring_impl { ...@@ -63,78 +81,87 @@ struct memory_coloring_impl {
void build(); void build();
void run(); void run();
void rewrite(); void rewrite();
private: private:
bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; } bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; }
bool isOutputParam(const instruction_ref ins) bool isOutputParam(const instruction_ref ins)
{ {
return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "output"; return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
} }
bool isScratchParam(const instruction_ref ins) bool isScratchParam(const instruction_ref ins)
{ {
return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "scratch"; return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "scratch";
} }
bool isAllocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; } bool isAllocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; } bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; } bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; }
bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; } bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; }
bool isTranspose(const instruction_ref ins) { return ins->op.name() == "transpose"; } bool isTranspose(const instruction_ref ins) { return ins->op.name() == "transpose"; }
int getInputTieNdx(const instruction_ref ins) { int getInputTieNdx(const instruction_ref ins)
if (isTranspose(ins)) {
if(isTranspose(ins))
return 0; return 0;
int cnt = -1; int cnt = -1;
int last_allocate = -1; int last_allocate = -1;
for (auto&& arg : ins->arguments) { for(auto&& arg : ins->arguments)
{
cnt++; cnt++;
if (isAllocate(arg)) if(isAllocate(arg))
last_allocate = cnt; last_allocate = cnt;
} }
return last_allocate; return last_allocate;
} }
// True when the byte ranges [offset, offset + size - 1] of the two live ranges
// do not overlap.
bool isDisjoin(T_live_range& range1, T_live_range& range2)
{
    const long long last1 = range1.offset + range1.size - 1;
    const long long last2 = range2.offset + range2.size - 1;
    // The ranges overlap exactly when each one ends at or after the other's start.
    const bool overlap = (last1 >= range2.offset) && (last2 >= range1.offset);
    return !overlap;
}
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void dump(std::string); void dump(std::string);
void dump_program(); void dump_program();
void dump_intervals(); void dump_intervals();
void verify(); void verify();
#endif #endif
struct ordering { struct ordering
bool operator() (const T_live_interval* I1, const T_live_interval* I2) const {
bool operator()(const interval_ptr I1, const interval_ptr I2) const
{ {
int len1 = I1->get_end() - I1->get_begin(); int len1 = I1->get_end() - I1->get_begin();
int len2 = I2->get_end() - I2->get_begin(); int len2 = I2->get_end() - I2->get_begin();
if (len1 != len2) { if(len1 != len2)
{
return (len1 < len2) ? true : false; return (len1 < len2) ? true : false;
} else if (I1->result.bytes() != I2->result.bytes()) { }
else if(I1->result.bytes() != I2->result.bytes())
{
return (I1->result.bytes() < I2->result.bytes()) ? true : false; return (I1->result.bytes() < I2->result.bytes()) ? true : false;
} else { }
else
{
return I1->id > I2->id; return I1->id > I2->id;
} }
} }
bool operator() (const T_live_range* I1, const T_live_range* I2) const bool operator()(const T_live_range* I1, const T_live_range* I2) const
{ {
return (I1->offset > I2->offset); return (I1->offset > I2->offset);
} }
}; };
program* p_program; program* p_program;
std::unordered_map<const instruction*, T_live_interval*> instr2Live; std::unordered_map<const instruction*, interval_ptr> instr2Live;
// Map live interval Id to live interval. // Map live interval Id to live interval.
std::unordered_map<int, T_live_interval*> live_intervals; std::unordered_map<int, interval_ptr> live_intervals;
// Map live range value number to live range. // Map live range value number to live range.
std::unordered_map<int, T_live_range*> live_ranges; std::unordered_map<int, T_live_range*> live_ranges;
// Map live range value number to a set of conflicting live ranges' value numbers. // Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table; std::unordered_map<int, std::set<int>> conflict_table;
// Priority queue for coloring. // Priority queue for coloring.
std::priority_queue<T_live_interval*, std::vector<T_live_interval*>, ordering> alloc_queue; std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue;
int num_of_lives; int num_of_lives;
int max_value_number; int max_value_number;
long long required_bytes; long long required_bytes;
}; };
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -94,7 +94,5 @@ void copy_to_gpu(char* dst, const char* src, std::size_t size) ...@@ -94,7 +94,5 @@ void copy_to_gpu(char* dst, const char* src, std::size_t size)
{ {
hipMemcpy(dst, src, size, hipMemcpyHostToDevice); hipMemcpy(dst, src, size, hipMemcpyHostToDevice);
} }
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
...@@ -85,22 +85,18 @@ struct hip_write ...@@ -85,22 +85,18 @@ struct hip_write
struct hip_memcpy struct hip_memcpy
{ {
std::string name() const { return "hip_memcpy"; } std::string name() const { return "hip_memcpy"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return inputs.at(2); std::size_t* p_data = reinterpret_cast<std::size_t*>(args.at(1).data());
} char* dst = args.at(0).data() + p_data[0];
argument compute(context&, shape output_shape, std::vector<argument> args) const { const char* src = args.at(2).data();
std::size_t * p_data = reinterpret_cast<std::size_t*>(args.at(1).data()); std::size_t size = args.at(2).get_shape().bytes();
char* dst = args.at(0).data() + p_data[0];
const char* src = args.at(2).data();
std::size_t size = args.at(2).get_shape().bytes();
copy_to_gpu(dst, src, size); copy_to_gpu(dst, src, size);
return {output_shape, dst}; return {output_shape, dst};
} }
}; };
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -37,13 +37,12 @@ void write_literals::apply(program& p) const ...@@ -37,13 +37,12 @@ void write_literals::apply(program& p) const
p.replace_instruction(ins, hip_load_literal{a.get_shape(), n}); p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
} }
#else #else
if (ins->op.name() == "write_literal") { if(ins->op.name() == "write_literal")
{
p.replace_instruction(ins, hip_memcpy{}, ins->arguments); p.replace_instruction(ins, hip_memcpy{}, ins->arguments);
} }
#endif #endif
} }
} }
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment