Commit 371a0f29 authored by mei-ye's avatar mei-ye
Browse files

more coding conventions fix

parent 7fa4d978
......@@ -538,7 +538,7 @@ struct get_mem_ptr
{
std::string name() const { return "get_mem_ptr:" + std::to_string(offset); }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(1); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context&, const shape& output_shape, const std::vector<argument>& args) const
{
return {std::move(output_shape), args.at(0).data() + offset};
}
......@@ -549,7 +549,7 @@ struct write_literal
{
std::string name() const { return "write_literal"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
argument compute(context&, shape, std::vector<argument>) const
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
......
......@@ -93,8 +93,6 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
int get_size() const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
......@@ -11,12 +11,12 @@
#include <vector>
#include <queue>
//#define DEBUG_OPT
//#define MIGRAPH_DEBUG_OPT
#ifdef DEBUG_OPT
#define DEBUG(s) s
#ifdef MIGRAPH_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s
#else
#define DEBUG(s)
#endif // DEBUG_OPT
#define MIGRAPH_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
......@@ -2,6 +2,7 @@
#include "memory_coloring_impl.hpp"
namespace migraph {
void memory_coloring::apply(program& p) const
{
memory_coloring_impl opt(&p);
......
#include "memory_coloring_impl.hpp"
namespace migraph {
void memory_coloring_impl::run()
{
build();
if(num_of_lives != 0)
{
DEBUG(dump("---Before memory coloring---"));
DEBUG(dump_program());
DEBUG(dump_intervals());
MIGRAPH_DEBUG(dump("---Before memory coloring---"));
MIGRAPH_DEBUG(dump_program());
MIGRAPH_DEBUG(dump_intervals());
// Coloring
while(!alloc_queue.empty())
{
......@@ -17,7 +18,7 @@ void memory_coloring_impl::run()
alloc_queue.pop();
}
rewrite();
DEBUG(verify());
MIGRAPH_DEBUG(verify());
}
}
......@@ -77,101 +78,94 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
conflict_queue.pop();
}
segment.offset = offset;
DEBUG(segment.dump());
MIGRAPH_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
void memory_coloring_impl::build()
{
int num_of_instrs = p_program->get_size();
if(num_of_instrs > 0)
std::size_t num_of_instrs = p_program->size();
if(num_of_instrs == 0)
return;
int cur_points = num_of_instrs * 2;
instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{
int cur_points = num_of_instrs * 2;
instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
const instruction* p_iter = &(*iter);
interval_ptr def_interval = nullptr;
bool is_dead = false;
if(instr2_live.find(p_iter) != instr2_live.end())
{
const instruction* p_iter = &(*iter);
interval_ptr def_interval = nullptr;
bool is_dead = false;
if(instr2_live.find(p_iter) != instr2_live.end())
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) || is_lit)
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) || is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->result;
def_interval->is_literal = is_lit;
alloc_queue.push(def_interval);
range.begin = cur_points;
range.size = (iter->result).bytes();
live_set.erase(range.vn);
}
live_range& range = def_interval->segment;
def_interval->result = iter->result;
def_interval->is_literal = is_lit;
alloc_queue.push(def_interval);
range.begin = cur_points;
range.size = (iter->result).bytes();
live_set.erase(range.vn);
}
else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
}
else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
{
is_dead = true;
}
int tie_ndx = get_input_tie_ndx(iter);
int cnt = -1;
for(auto&& arg : iter->arguments)
{
cnt++;
if(is_param(arg) || is_outline(arg))
{
is_dead = true;
if(is_output_param(arg))
is_dead = false;
continue;
}
int tie_ndx = get_input_tie_ndx(iter);
if(!iter->arguments.empty())
const instruction* p_arg = &(*arg);
if(cnt == tie_ndx)
{
int cnt = -1;
for(auto&& arg : iter->arguments)
{
cnt++;
if(is_param(arg) || is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
continue;
}
const instruction* p_arg = &(*arg);
if(cnt == tie_ndx)
{
// input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals.
#ifndef NDEBUG
assert(def_interval != nullptr);
#endif
def_interval->add_use(cur_points);
instr2_live[p_arg] = def_interval;
}
else if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
#ifndef NDEBUG
assert(live_set.find(interval->id) != live_set.end());
#endif
}
}
// input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals.
assert(def_interval != nullptr);
def_interval->add_use(cur_points);
instr2_live[p_arg] = def_interval;
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
iter = std::prev(iter);
} while(iter != begin);
}
else if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
}
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
iter = std::prev(iter);
} while(iter != begin);
}
void memory_coloring_impl::rewrite()
......@@ -190,9 +184,7 @@ void memory_coloring_impl::rewrite()
interval_ptr interval = instr2_live[p_iter];
if(interval->get_offset() == InvalidOffset)
{
#ifndef NDEBUG
assert((interval->get_begin() == InvalidOffset) || interval->result.bytes() == 0);
#endif
continue;
}
std::size_t offset = interval->get_offset();
......@@ -209,11 +201,11 @@ void memory_coloring_impl::rewrite()
}
}
}
DEBUG(dump("---After rewrite---"));
DEBUG(dump_program());
MIGRAPH_DEBUG(dump("---After rewrite---"));
MIGRAPH_DEBUG(dump_program());
}
#ifdef DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
void memory_coloring_impl::dump(const std::string str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; }
......
......@@ -4,7 +4,7 @@
namespace migraph {
#define InvalidOffset (-1)
static const int InvalidOffset = -1;
struct live_range
{
......@@ -13,7 +13,7 @@ struct live_range
long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range.
long long size; // size of required memory in bytes
#ifdef DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
void dump();
#endif
};
......@@ -31,7 +31,7 @@ struct live_interval
int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
void dump();
#endif
......@@ -42,12 +42,11 @@ struct live_interval
bool is_literal;
};
#define interval_ptr live_interval*
typedef live_interval* interval_ptr;
struct memory_coloring_impl
{
memory_coloring_impl(program* p) : p_program(p) { init(); }
void init()
memory_coloring_impl(program* p) : p_program(p)
{
instr2_live.clear();
live_ranges.clear();
......@@ -70,16 +69,19 @@ struct memory_coloring_impl
void rewrite();
private:
bool is_param(const instruction_ref ins) { return ins->op.name() == "@param"; }
bool is_output_param(const instruction_ref ins)
static bool is_param(const instruction_ref ins) { return ins->op.name() == "@param"; }
static bool is_output_param(const instruction_ref ins)
{
return is_param(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
}
bool is_allocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; }
bool is_check_context(const instruction_ref ins) { return ins->op.name() == "check_context"; }
bool is_transpose(const instruction_ref ins) { return ins->op.name() == "transpose"; }
static bool is_allocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
static bool is_outline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->op.name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
{
return ins->op.name() == "check_context";
}
static bool is_transpose(const instruction_ref ins) { return ins->op.name() == "transpose"; }
int get_input_tie_ndx(const instruction_ref ins)
{
if(is_transpose(ins))
......@@ -94,8 +96,8 @@ struct memory_coloring_impl
}
return last_allocate;
}
#ifdef DEBUG_OPT
bool is_disjoin(live_range& range1, live_range& range2)
#ifdef MIGRAPH_DEBUG_OPT
static bool is_disjoin(live_range& range1, live_range& range2)
{
long long end1 = range1.offset + range1.size - 1;
long long end2 = range2.offset + range2.size - 1;
......
......@@ -320,10 +320,6 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
}
int program::get_size() const
{
return (*impl).instructions.size();
}
double common_average(const std::vector<double>& v)
{
......@@ -428,5 +424,4 @@ std::ostream& operator<<(std::ostream& os, const program& p)
print_program(os, p, [](auto&&...) {});
return os;
}
} // namespace migraph
......@@ -58,7 +58,9 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{
auto result = allocate_gpu(sz, host);
// gpu_sync();
auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
if(status != hipSuccess)
MIGRAPH_THROW("Copy to gpu failed: " + hip_error(status));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment