Commit 371a0f29 authored by mei-ye's avatar mei-ye
Browse files

more coding conventions fix

parent 7fa4d978
...@@ -538,7 +538,7 @@ struct get_mem_ptr ...@@ -538,7 +538,7 @@ struct get_mem_ptr
{ {
std::string name() const { return "get_mem_ptr:" + std::to_string(offset); } std::string name() const { return "get_mem_ptr:" + std::to_string(offset); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, const std::vector<argument>& args) const
{ {
return {std::move(output_shape), args.at(0).data() + offset}; return {std::move(output_shape), args.at(0).data() + offset};
} }
...@@ -549,7 +549,7 @@ struct write_literal ...@@ -549,7 +549,7 @@ struct write_literal
{ {
std::string name() const { return "write_literal"; } std::string name() const { return "write_literal"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, const shape&, const std::vector<argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPH_THROW("not computable");
} }
......
...@@ -93,8 +93,6 @@ struct program ...@@ -93,8 +93,6 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
int get_size() const;
friend std::ostream& operator<<(std::ostream& os, const program& p); friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
...@@ -11,12 +11,12 @@ ...@@ -11,12 +11,12 @@
#include <vector> #include <vector>
#include <queue> #include <queue>
//#define DEBUG_OPT //#define MIGRAPH_DEBUG_OPT
#ifdef DEBUG_OPT #ifdef MIGRAPH_DEBUG_OPT
#define DEBUG(s) s #define MIGRAPH_DEBUG(s) s
#else #else
#define DEBUG(s) #define MIGRAPH_DEBUG(s)
#endif // DEBUG_OPT #endif // MIGRAPH_DEBUG_OPT
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraph {
void memory_coloring::apply(program& p) const void memory_coloring::apply(program& p) const
{ {
memory_coloring_impl opt(&p); memory_coloring_impl opt(&p);
......
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraph { namespace migraph {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{ {
build(); build();
if(num_of_lives != 0) if(num_of_lives != 0)
{ {
DEBUG(dump("---Before memory coloring---")); MIGRAPH_DEBUG(dump("---Before memory coloring---"));
DEBUG(dump_program()); MIGRAPH_DEBUG(dump_program());
DEBUG(dump_intervals()); MIGRAPH_DEBUG(dump_intervals());
// Coloring // Coloring
while(!alloc_queue.empty()) while(!alloc_queue.empty())
{ {
...@@ -17,7 +18,7 @@ void memory_coloring_impl::run() ...@@ -17,7 +18,7 @@ void memory_coloring_impl::run()
alloc_queue.pop(); alloc_queue.pop();
} }
rewrite(); rewrite();
DEBUG(verify()); MIGRAPH_DEBUG(verify());
} }
} }
...@@ -77,101 +78,94 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -77,101 +78,94 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
conflict_queue.pop(); conflict_queue.pop();
} }
segment.offset = offset; segment.offset = offset;
DEBUG(segment.dump()); MIGRAPH_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size); required_bytes = std::max(required_bytes, offset + segment.size);
return true; return true;
} }
void memory_coloring_impl::build() void memory_coloring_impl::build()
{ {
int num_of_instrs = p_program->get_size(); std::size_t num_of_instrs = p_program->size();
if(num_of_instrs > 0) if(num_of_instrs == 0)
return;
int cur_points = num_of_instrs * 2;
instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{ {
int cur_points = num_of_instrs * 2; const instruction* p_iter = &(*iter);
instruction_ref iter = std::prev(p_program->end()); interval_ptr def_interval = nullptr;
instruction_ref begin = p_program->begin(); bool is_dead = false;
std::vector<instruction_ref> dead_instrs; if(instr2_live.find(p_iter) != instr2_live.end())
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{ {
const instruction* p_iter = &(*iter); def_interval = instr2_live[p_iter];
interval_ptr def_interval = nullptr; bool is_lit = is_literal(iter);
bool is_dead = false; if(is_allocate(iter) || is_lit)
if(instr2_live.find(p_iter) != instr2_live.end())
{ {
def_interval = instr2_live[p_iter]; live_range& range = def_interval->segment;
bool is_lit = is_literal(iter); def_interval->result = iter->result;
if(is_allocate(iter) || is_lit) def_interval->is_literal = is_lit;
{ alloc_queue.push(def_interval);
live_range& range = def_interval->segment; range.begin = cur_points;
def_interval->result = iter->result; range.size = (iter->result).bytes();
def_interval->is_literal = is_lit; live_set.erase(range.vn);
alloc_queue.push(def_interval);
range.begin = cur_points;
range.size = (iter->result).bytes();
live_set.erase(range.vn);
}
} }
else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter)) }
else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
{
is_dead = true;
}
int tie_ndx = get_input_tie_ndx(iter);
int cnt = -1;
for(auto&& arg : iter->arguments)
{
cnt++;
if(is_param(arg) || is_outline(arg))
{ {
is_dead = true; if(is_output_param(arg))
is_dead = false;
continue;
} }
int tie_ndx = get_input_tie_ndx(iter); const instruction* p_arg = &(*arg);
if(!iter->arguments.empty()) if(cnt == tie_ndx)
{ {
int cnt = -1; // input memory is used as this instruction's output.
for(auto&& arg : iter->arguments) // def is considered as use. Coalesce the live intervals.
{ assert(def_interval != nullptr);
cnt++; def_interval->add_use(cur_points);
if(is_param(arg) || is_outline(arg)) instr2_live[p_arg] = def_interval;
{
if(is_output_param(arg))
is_dead = false;
continue;
}
const instruction* p_arg = &(*arg);
if(cnt == tie_ndx)
{
// input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals.
#ifndef NDEBUG
assert(def_interval != nullptr);
#endif
def_interval->add_use(cur_points);
instr2_live[p_arg] = def_interval;
}
else if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
#ifndef NDEBUG
assert(live_set.find(interval->id) != live_set.end());
#endif
}
}
} }
if(is_dead) else if(instr2_live.find(p_arg) == instr2_live.end())
dead_instrs.push_back(iter); {
cur_points -= 2; // First time see a use, create a live interval.
iter = std::prev(iter); int id = num_of_lives++;
} while(iter != begin); interval_ptr interval = &(live_intervals[id]);
} interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
}
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
iter = std::prev(iter);
} while(iter != begin);
} }
void memory_coloring_impl::rewrite() void memory_coloring_impl::rewrite()
...@@ -190,9 +184,7 @@ void memory_coloring_impl::rewrite() ...@@ -190,9 +184,7 @@ void memory_coloring_impl::rewrite()
interval_ptr interval = instr2_live[p_iter]; interval_ptr interval = instr2_live[p_iter];
if(interval->get_offset() == InvalidOffset) if(interval->get_offset() == InvalidOffset)
{ {
#ifndef NDEBUG
assert((interval->get_begin() == InvalidOffset) || interval->result.bytes() == 0); assert((interval->get_begin() == InvalidOffset) || interval->result.bytes() == 0);
#endif
continue; continue;
} }
std::size_t offset = interval->get_offset(); std::size_t offset = interval->get_offset();
...@@ -209,11 +201,11 @@ void memory_coloring_impl::rewrite() ...@@ -209,11 +201,11 @@ void memory_coloring_impl::rewrite()
} }
} }
} }
DEBUG(dump("---After rewrite---")); MIGRAPH_DEBUG(dump("---After rewrite---"));
DEBUG(dump_program()); MIGRAPH_DEBUG(dump_program());
} }
#ifdef DEBUG_OPT #ifdef MIGRAPH_DEBUG_OPT
void memory_coloring_impl::dump(const std::string str) { std::cout << str << std::endl; } void memory_coloring_impl::dump(const std::string str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; } void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
namespace migraph { namespace migraph {
#define InvalidOffset (-1) static const int InvalidOffset = -1;
struct live_range struct live_range
{ {
...@@ -13,7 +13,7 @@ struct live_range ...@@ -13,7 +13,7 @@ struct live_range
long long offset; // offset to base pointer of allocated memory trunk. long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range. int vn; // value number that identifies this live_range.
long long size; // size of required memory in bytes long long size; // size of required memory in bytes
#ifdef DEBUG_OPT #ifdef MIGRAPH_DEBUG_OPT
void dump(); void dump();
#endif #endif
}; };
...@@ -31,7 +31,7 @@ struct live_interval ...@@ -31,7 +31,7 @@ struct live_interval
int get_end() const { return segment.end; } int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; } long long get_offset() const { return segment.offset; }
#ifdef DEBUG_OPT #ifdef MIGRAPH_DEBUG_OPT
void dump(); void dump();
#endif #endif
...@@ -42,12 +42,11 @@ struct live_interval ...@@ -42,12 +42,11 @@ struct live_interval
bool is_literal; bool is_literal;
}; };
#define interval_ptr live_interval* typedef live_interval* interval_ptr;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p) : p_program(p) { init(); } memory_coloring_impl(program* p) : p_program(p)
void init()
{ {
instr2_live.clear(); instr2_live.clear();
live_ranges.clear(); live_ranges.clear();
...@@ -70,16 +69,19 @@ struct memory_coloring_impl ...@@ -70,16 +69,19 @@ struct memory_coloring_impl
void rewrite(); void rewrite();
private: private:
bool is_param(const instruction_ref ins) { return ins->op.name() == "@param"; } static bool is_param(const instruction_ref ins) { return ins->op.name() == "@param"; }
bool is_output_param(const instruction_ref ins) static bool is_output_param(const instruction_ref ins)
{ {
return is_param(ins) && any_cast<builtin::param>(ins->op).parameter == "output"; return is_param(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
} }
bool is_allocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; } static bool is_allocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; } static bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; } static bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; }
bool is_check_context(const instruction_ref ins) { return ins->op.name() == "check_context"; } static bool is_check_context(const instruction_ref ins)
bool is_transpose(const instruction_ref ins) { return ins->op.name() == "transpose"; } {
return ins->op.name() == "check_context";
}
static bool is_transpose(const instruction_ref ins) { return ins->op.name() == "transpose"; }
int get_input_tie_ndx(const instruction_ref ins) int get_input_tie_ndx(const instruction_ref ins)
{ {
if(is_transpose(ins)) if(is_transpose(ins))
...@@ -94,8 +96,8 @@ struct memory_coloring_impl ...@@ -94,8 +96,8 @@ struct memory_coloring_impl
} }
return last_allocate; return last_allocate;
} }
#ifdef DEBUG_OPT #ifdef MIGRAPH_DEBUG_OPT
bool is_disjoin(live_range& range1, live_range& range2) static bool is_disjoin(live_range& range1, live_range& range2)
{ {
long long end1 = range1.offset + range1.size - 1; long long end1 = range1.offset + range1.size - 1;
long long end2 = range2.offset + range2.size - 1; long long end2 = range2.offset + range2.size - 1;
......
...@@ -320,10 +320,6 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -320,10 +320,6 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
return generic_eval( return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); }); *this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
} }
int program::get_size() const
{
return (*impl).instructions.size();
}
double common_average(const std::vector<double>& v) double common_average(const std::vector<double>& v)
{ {
...@@ -428,5 +424,4 @@ std::ostream& operator<<(std::ostream& os, const program& p) ...@@ -428,5 +424,4 @@ std::ostream& operator<<(std::ostream& os, const program& p)
print_program(os, p, [](auto&&...) {}); print_program(os, p, [](auto&&...) {});
return os; return os;
} }
} // namespace migraph } // namespace migraph
...@@ -58,7 +58,9 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz) ...@@ -58,7 +58,9 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false) hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{ {
auto result = allocate_gpu(sz, host); auto result = allocate_gpu(sz, host);
// gpu_sync();
auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPH_THROW("Copy to gpu failed: " + hip_error(status)); MIGRAPH_THROW("Copy to gpu failed: " + hip_error(status));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment