"examples/research_projects/diffusion_dpo/REAMDE.md" did not exist on "ae060fc4f1ba8b9b9a7de35888138415808bfcd6"
Commit 70837f1a authored by mei-ye's avatar mei-ye
Browse files

add memory coloring

parent 686b9ea9
...@@ -550,6 +550,31 @@ struct div : binary ...@@ -550,6 +550,31 @@ struct div : binary
std::string name() const { return "div"; } std::string name() const { return "div"; }
}; };
struct get_mem_ptr
{
std::string name() const { return "get_mem_ptr:" + std::to_string(offset); }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(1);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
return {output_shape, args.at(0).data() + offset};
}
std::size_t offset = 0;
};
struct write_literal
{
std::string name() const { return "write_literal"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(2);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
assert(false);
}
};
struct outline struct outline
{ {
shape s; shape s;
......
...@@ -6,50 +6,68 @@ void memory_coloring_impl::run() ...@@ -6,50 +6,68 @@ void memory_coloring_impl::run()
build(); build();
if (num_of_lives != 0) { if (num_of_lives != 0) {
DEBUG(dump("---Before memory coloring---")); DEBUG(dump("---Before memory coloring---"));
DEBUG(dump()); DEBUG(dump(p_program));
// Coloring // Coloring
while (!alloc_queue.empty()) { while (!alloc_queue.empty()) {
T_live_interval* interval = alloc_queue.top(); T_live_interval* interval = alloc_queue.top();
allocate(interval); allocate(interval);
alloc_queue.pop(); alloc_queue.pop();
} }
for (int i = 0; i < num_of_lives; ++i) rewrite();
DEBUG(verify());
for (int i = 0; i < num_of_lives; ++i) {
free(live_intervals[i]); free(live_intervals[i]);
} }
}
} }
bool memory_coloring_impl::allocate(T_live_interval* interval) bool memory_coloring_impl::allocate(T_live_interval* interval)
{ {
shape s = interval->result; shape s = interval->result;
int size = s.bytes(); std::size_t size = s.bytes();
std::size_t element_size = size / s.elements(); std::size_t element_size = size / s.elements();
T_live_range& segment = interval->segment; T_live_range& segment = interval->segment;
int vn = segment.vn; int vn = segment.vn;
std::priority_queue<T_live_range*, std::vector<T_live_range*>, ordering> conflict_queue; std::priority_queue<T_live_range*, std::vector<T_live_range*>, ordering> conflict_queue;
std::unordered_map<long long, T_live_range*> offset2Live;
offset2Live.clear();
if (conflict_table.find(vn) != conflict_table.end()) { if (conflict_table.find(vn) != conflict_table.end()) {
std::set<int>& vn_set = conflict_table[vn]; std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) { for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) {
T_live_range* range = live_ranges[*iter]; T_live_range* range = live_ranges[*iter];
if (range->offset != -1) long long offset = range->offset;
if (offset != InvalidOffset) {
conflict_queue.push(range); conflict_queue.push(range);
if (offset2Live.find(offset) == offset2Live.end()) {
offset2Live[offset] = range;
} else {
T_live_range* prev = offset2Live[offset];
assert(prev->offset == offset);
if (prev->size < range->size)
offset2Live[offset] = range;
}
}
} }
} }
int offset = 0; long long offset = 0;
while (!conflict_queue.empty()) { while (!conflict_queue.empty()) {
T_live_range* range = conflict_queue.top(); T_live_range* range = conflict_queue.top();
int cur_offset = range->offset; long long cur_offset = range->offset;
if (offset2Live[cur_offset] == range) {
if ((cur_offset > offset) && (cur_offset - offset) >= size) { if ((cur_offset > offset) && (cur_offset - offset) >= size) {
break; break;
} }
offset = cur_offset + range->size; offset = cur_offset + range->size;
if ((offset % element_size) != 0) if ((offset % element_size) != 0)
offset += (element_size - (offset % element_size)); offset += (element_size - (offset % element_size));
}
conflict_queue.pop(); conflict_queue.pop();
} }
segment.offset = offset; segment.offset = offset;
DEBUG(segment.dump()); DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true; return true;
} }
...@@ -62,9 +80,7 @@ void memory_coloring_impl::build() ...@@ -62,9 +80,7 @@ void memory_coloring_impl::build()
instruction_ref iter = std::prev(p_program->end()); instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin(); instruction_ref begin = p_program->begin();
std::vector<instruction_ref> dead_instrs; std::vector<instruction_ref> dead_instrs;
std::unordered_map<const instruction*, T_live_interval*> instr2Live;
std::set<int> live_set; std::set<int> live_set;
T_live_interval* next_def = nullptr;
// Build live intervals. // Build live intervals.
do { do {
const instruction* p_iter = &(*iter); const instruction* p_iter = &(*iter);
...@@ -80,23 +96,25 @@ void memory_coloring_impl::build() ...@@ -80,23 +96,25 @@ void memory_coloring_impl::build()
alloc_queue.push(def_interval); alloc_queue.push(def_interval);
range.begin = cur_points; range.begin = cur_points;
range.size = (iter->result).bytes(); range.size = (iter->result).bytes();
next_def = def_interval;
live_set.erase(range.vn); live_set.erase(range.vn);
} }
} else if (!isParam(iter) && !isOutline(iter) && !isCheckContext(iter)) { } else if (!isParam(iter) && !isOutline(iter) && !isCheckContext(iter)) {
isDead = true; isDead = true;
} }
int tieNdx = getInputTieNdx(iter);
if (!iter->arguments.empty()) { if (!iter->arguments.empty()) {
int cnt = -1;
for (auto&& arg : iter->arguments) { for (auto&& arg : iter->arguments) {
cnt++;
if (isParam(arg) || isOutline(arg)) { if (isParam(arg) || isOutline(arg)) {
if (isOutputParam(arg)) if (isOutputParam(arg))
isDead = false; isDead = false;
continue; continue;
} }
const instruction* p_arg = &(*arg); const instruction* p_arg = &(*arg);
if (isAllocate(arg)) { if (cnt == tieNdx) {
// input is from hip::allocate, def is considered as use // input memory is used as this instruction's output.
// and coalesce the live intervals. // def is considered as use. Coalesce the live intervals.
def_interval->addUse(cur_points); def_interval->addUse(cur_points);
instr2Live[p_arg] = def_interval; instr2Live[p_arg] = def_interval;
} else if (instr2Live.find(p_arg) == instr2Live.end()) { } else if (instr2Live.find(p_arg) == instr2Live.end()) {
...@@ -112,10 +130,6 @@ void memory_coloring_impl::build() ...@@ -112,10 +130,6 @@ void memory_coloring_impl::build()
live_set.insert(max_value_number); live_set.insert(max_value_number);
live_intervals[id] = interval; live_intervals[id] = interval;
live_ranges[max_value_number] = &(interval->segment); live_ranges[max_value_number] = &(interval->segment);
// Keep track of live intervals that are inactive when
// next_def is enqueued.
if (next_def != nullptr)
next_def->inactive_afters.push_back(interval);
} else { } else {
T_live_interval* interval = instr2Live[p_arg]; T_live_interval* interval = instr2Live[p_arg];
interval->addUse(cur_points); interval->addUse(cur_points);
...@@ -130,10 +144,57 @@ void memory_coloring_impl::build() ...@@ -130,10 +144,57 @@ void memory_coloring_impl::build()
} while (iter != begin); } while (iter != begin);
} }
void memory_coloring_impl::rewrite()
{
instruction_ref end = p_program->end();
instruction_ref scratch_param = end;
for (auto ins : iterator_for(*p_program)) {
const instruction* p_iter = &(*ins);
if (isScratchParam(ins)) {
scratch_param = ins;
int allocated_bytes = ins->result.bytes();
if (allocated_bytes < required_bytes) {
std::cout << "required bytes: " << required_bytes << "allocated bytes: " << allocated_bytes << std::endl;
throw std::runtime_error("insufficent memory for MIGraph");
}
#ifdef DEBUG_OPT
float frac = 1.0 * required_bytes/allocated_bytes*100;
std::cout << "memory usage percentage: " << to_string(frac) << "%" << std::endl;
#endif
}
if (instr2Live.find(p_iter) != instr2Live.end()) {
T_live_interval* interval = instr2Live[p_iter];
if (interval->get_offset() == InvalidOffset) {
DEBUG(assert(interval->get_begin() == InvalidOffset));
continue;
}
std::size_t offset = interval->get_offset();
if (isAllocate(ins)) {
if (scratch_param == end)
throw std::runtime_error("missing scratch parameter");
p_program->replace_instruction(ins, get_mem_ptr{offset}, scratch_param, ins->arguments.at(0));
} else if (isLiteral(ins)) {
if (scratch_param == end)
throw std::runtime_error("missing scratch parameter");
auto pre = p_program->add_literal(ins->lit);
auto index = p_program->add_literal(offset);
p_program->replace_instruction(ins, write_literal{}, scratch_param, index, pre);
}
}
}
DEBUG(dump("---After rewrite---"));
DEBUG(dump(p_program));
}
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void memory_coloring_impl::dump(std::string str) void memory_coloring_impl::dump(std::string str)
{ {
std::cout << str << std::endl; std::cout << str << std::endl;
}
void memory_coloring_impl::dump(program* p_program)
{
std::cout << *p_program << std::endl; std::cout << *p_program << std::endl;
} }
...@@ -145,7 +206,7 @@ void memory_coloring_impl::dump() ...@@ -145,7 +206,7 @@ void memory_coloring_impl::dump()
T_live_interval* interval = live_intervals[i]; T_live_interval* interval = live_intervals[i];
interval->dump(); interval->dump();
} }
std::cout << "conflict table:" << std::endl; std::cout << "---conflict table---" << std::endl;
for (int i = 0; i <= max_value_number; ++i) { for (int i = 0; i <= max_value_number; ++i) {
std::cout << " segment:" << i; std::cout << " segment:" << i;
std::cout << " =>"; std::cout << " =>";
...@@ -158,15 +219,38 @@ void memory_coloring_impl::dump() ...@@ -158,15 +219,38 @@ void memory_coloring_impl::dump()
} }
} }
void memory_coloring_impl::verify()
{
if (num_of_lives > 0) {
for (int i = 0; i < num_of_lives; ++i) {
T_live_interval* interval = live_intervals[i];
T_live_range& segment = interval->segment;
if (segment.offset == InvalidOffset)
continue;
int vn = segment.vn;
if (conflict_table.find(vn) != conflict_table.end()) {
std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) {
T_live_range* range = live_ranges[*iter];
if (range->offset == InvalidOffset)
continue;
if (!isDisjoin(*range, segment))
assert(false);
}
}
}
}
}
#define GET_INS_ENUM(x) (((x) >> 1) - 1) #define GET_INS_ENUM(x) (((x) >> 1) - 1)
void live_range::dump() void live_range::dump()
{ {
std::cout << " segment:" << vn; std::cout << " segment:" << vn;
std::cout << " [" << GET_INS_ENUM(begin) << ", " << GET_INS_ENUM(end) << "]"; std::cout << " [" << GET_INS_ENUM(begin) << ", " << GET_INS_ENUM(end) << "]";
if (offset != -1) { if (offset != InvalidOffset) {
std::cout << " mem:"; std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size << "]"; std::cout << " [" << offset << "," << offset + size - 1 << "]";
} }
std::cout << std::endl; std::cout << std::endl;
} }
...@@ -181,13 +265,6 @@ void live_interval::dump() ...@@ -181,13 +265,6 @@ void live_interval::dump()
std::cout << " " << GET_INS_ENUM(use) << ","; std::cout << " " << GET_INS_ENUM(use) << ",";
} }
if (!inactive_afters.empty()) {
std::cout << " inactivate:";
for (auto iter = inactive_afters.begin(), end = inactive_afters.end(); iter != end; ++iter) {
T_live_interval*& interval = *iter;
std::cout << " " << interval->id << ",";
}
}
if (isLiteral) if (isLiteral)
std::cout << " literal"; std::cout << " literal";
std::cout << " " << result; std::cout << " " << result;
......
...@@ -4,12 +4,14 @@ ...@@ -4,12 +4,14 @@
namespace migraph { namespace migraph {
#define InvalidOffset -1
typedef struct live_range { typedef struct live_range {
int begin; // begin point in the instruction stream. int begin; // begin point in the instruction stream.
int end; // end point in the instruction stream. int end; // end point in the instruction stream.
int offset; // offset to base pointer of allocated memory trunk. long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range. int vn; // value number that identifies this live_range.
int size; // size of required memory in bytes long long size; // size of required memory in bytes
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void dump(); void dump();
#endif #endif
...@@ -20,11 +22,12 @@ typedef struct live_interval { ...@@ -20,11 +22,12 @@ typedef struct live_interval {
void init() { void init() {
id = -1; isLiteral = false; id = -1; isLiteral = false;
segment = { -1, -1, -1, -1, 0}; segment = { -1, -1, InvalidOffset, -1, 0};
} }
void addUse(int use) { use_points.push_front(use); } void addUse(int use) { use_points.push_front(use); }
int get_begin() const { return segment.begin; } int get_begin() const { return segment.begin; }
int get_end() const { return segment.end; } int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void dump(); void dump();
...@@ -33,9 +36,6 @@ typedef struct live_interval { ...@@ -33,9 +36,6 @@ typedef struct live_interval {
T_live_range segment; T_live_range segment;
int id; int id;
std::list<int> use_points; std::list<int> use_points;
// Live intervals that are inactive when this live interval is enqueued.
// can be used for live interval collapsing.
std::list<struct live_interval*> inactive_afters;
shape result; shape result;
bool isLiteral; bool isLiteral;
...@@ -44,11 +44,13 @@ typedef struct live_interval { ...@@ -44,11 +44,13 @@ typedef struct live_interval {
struct memory_coloring_impl { struct memory_coloring_impl {
explicit memory_coloring_impl(program *p) : p_program(p) explicit memory_coloring_impl(program *p) : p_program(p)
{ {
instr2Live.clear();
live_intervals.clear(); live_intervals.clear();
live_ranges.clear(); live_ranges.clear();
conflict_table.clear(); conflict_table.clear();
num_of_lives = 0; num_of_lives = 0;
max_value_number = -1; max_value_number = -1;
required_bytes = 0;
} }
bool allocate(T_live_interval*); bool allocate(T_live_interval*);
void addConflicts(std::set<int>& live_set, int val) void addConflicts(std::set<int>& live_set, int val)
...@@ -60,20 +62,43 @@ struct memory_coloring_impl { ...@@ -60,20 +62,43 @@ struct memory_coloring_impl {
} }
void build(); void build();
void run(); void run();
void rewrite();
private: private:
bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; } bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; }
bool isOutputParam(const instruction_ref ins) bool isOutputParam(const instruction_ref ins)
{ {
return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "output"; return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
} }
bool isScratchParam(const instruction_ref ins)
{
return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "scratch";
}
bool isAllocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; } bool isAllocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; } bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; } bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; }
bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; } bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; }
bool isGPUTranspose(const instruction_ref ins) { return ins->op.name() == "gpu::transpose"; }
int getInputTieNdx(const instruction_ref ins) {
if (isGPUTranspose(ins))
return 0;
int cnt = -1;
for (auto&& arg : ins->arguments) {
cnt++;
if (isAllocate(arg))
return cnt;
}
return -1;
}
bool isDisjoin(T_live_range& range1, T_live_range& range2) {
long long end1 = range1.offset + range1.size - 1;
long long end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) || (end2 < range1.offset));
}
#ifdef DEBUG_OPT #ifdef DEBUG_OPT
void dump(std::string); void dump(std::string);
void dump(); void dump(program*);
void verify();
#endif #endif
struct ordering { struct ordering {
bool operator() (const T_live_interval* I1, const T_live_interval* I2) const bool operator() (const T_live_interval* I1, const T_live_interval* I2) const
...@@ -94,6 +119,7 @@ struct memory_coloring_impl { ...@@ -94,6 +119,7 @@ struct memory_coloring_impl {
} }
}; };
program* p_program; program* p_program;
std::unordered_map<const instruction*, T_live_interval*> instr2Live;
// Map live interval Id to live interval. // Map live interval Id to live interval.
std::unordered_map<int, T_live_interval*> live_intervals; std::unordered_map<int, T_live_interval*> live_intervals;
// Map live range value number to live range. // Map live range value number to live range.
...@@ -105,6 +131,7 @@ struct memory_coloring_impl { ...@@ -105,6 +131,7 @@ struct memory_coloring_impl {
int num_of_lives; int num_of_lives;
int max_value_number; int max_value_number;
long long required_bytes;
}; };
} // namespace migraph } // namespace migraph
......
...@@ -18,6 +18,8 @@ hip_ptr allocate_gpu(std::size_t sz) ...@@ -18,6 +18,8 @@ hip_ptr allocate_gpu(std::size_t sz)
hipMalloc(&result, sz); hipMalloc(&result, sz);
if (result == nullptr) if (result == nullptr)
throw std::runtime_error("can not allocate GPU memory"); throw std::runtime_error("can not allocate GPU memory");
char * ptr = reinterpret_cast<char*>(result);
std::cout << "MIGraph allocated mem: [" << result << "," << ptr + sz -1 << "]" << std::endl;
return hip_ptr{result}; return hip_ptr{result};
} }
...@@ -69,6 +71,11 @@ migraph::argument from_gpu(migraph::argument arg) ...@@ -69,6 +71,11 @@ migraph::argument from_gpu(migraph::argument arg)
return result; return result;
} }
void copy_to_gpu(char* dst, const char* src, std::size_t size)
{
hipMemcpy(dst, src, size, hipMemcpyHostToDevice);
}
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
...@@ -12,6 +12,8 @@ migraph::argument to_gpu(migraph::argument arg); ...@@ -12,6 +12,8 @@ migraph::argument to_gpu(migraph::argument arg);
migraph::argument from_gpu(migraph::argument arg); migraph::argument from_gpu(migraph::argument arg);
void copy_to_gpu(char* dst, const char* src, std::size_t size);
struct hip_allocate struct hip_allocate
{ {
std::string name() const { return "hip::allocate"; } std::string name() const { return "hip::allocate"; }
...@@ -40,6 +42,23 @@ struct hip_write ...@@ -40,6 +42,23 @@ struct hip_write
} }
}; };
struct hip_memcpy
{
std::string name() const { return "hip_memcpy"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(2);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
std::size_t * p_data = reinterpret_cast<std::size_t*>(args.at(1).data());
char* dst = args.at(0).data() + p_data[0];
const char* src = args.at(2).data();
std::size_t size = args.at(2).get_shape().bytes();
copy_to_gpu(dst, src, size);
return {output_shape, dst};
}
};
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
......
...@@ -11,12 +11,18 @@ void write_literals::apply(program& p) const ...@@ -11,12 +11,18 @@ void write_literals::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
#if 0
if(ins->op.name() == "@literal") if(ins->op.name() == "@literal")
{ {
literal l = ins->lit; literal l = ins->lit;
auto pre = p.add_literal(l); auto pre = p.add_literal(l);
p.replace_instruction(ins, hip_write{}, pre); p.replace_instruction(ins, hip_write{}, pre);
} }
#else
if (ins->op.name() == "write_literal") {
p.replace_instruction(ins, hip_memcpy{}, ins->arguments);
}
#endif
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment