"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0b4f2f8d24dd9be25f6039eb90eb70e0d6c352f0"
Commit d877a3fb authored by mei-ye's avatar mei-ye
Browse files

memory coloring pass

parents 58681660 80587d4c
......@@ -18,7 +18,7 @@ else()
set(MIGRAPH_ENABLE_GPU Off CACHE BOOL "")
endif()
add_compile_options(-std=c++14)
add_compile_options(-std=c++14 -g -O0)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings)
......
......@@ -8,6 +8,8 @@ add_library(migraph
program.cpp
shape.cpp
simplify_reshapes.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
)
rocm_clang_tidy_check(migraph)
target_include_directories(migraph PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
......@@ -17,3 +19,9 @@ add_subdirectory(targets/cpu)
if(MIGRAPH_ENABLE_GPU)
add_subdirectory(targets/gpu)
endif()
install (TARGETS migraph
LIBRARY DESTINATION /opt/rocm/lib)
install (DIRECTORY include/migraph DESTINATION /opt/rocm/include)
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct memory_coloring
{
std::string name() const { return "memory coloring"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
......@@ -534,6 +534,31 @@ struct div : binary
std::string name() const { return "div"; }
};
struct get_mem_ptr
{
std::string name() const { return "get_mem_ptr:" + std::to_string(offset); }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(1);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
return {output_shape, args.at(0).data() + offset};
}
std::size_t offset = 0;
};
struct write_literal
{
std::string name() const { return "write_literal"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(2);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
assert(false);
}
};
struct outline
{
shape s;
......
......@@ -93,6 +93,8 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
int get_size() const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/program.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#define DEBUG_OPT
#ifdef DEBUG_OPT
#define DEBUG(s) s
#else
#define DEBUG(s)
#endif // DEBUG_OPT
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraph/memory_coloring.hpp>
#include "memory_coloring_impl.hpp"
namespace migraph {
void memory_coloring::apply(program &p) const
{
memory_coloring_impl opt(&p);
opt.run();
}
} // namespace migraph
#include "memory_coloring_impl.hpp"
namespace migraph {
void memory_coloring_impl::run()
{
build();
if (num_of_lives != 0) {
DEBUG(dump("---Before memory coloring---"));
DEBUG(dump_program());
DEBUG(dump_intervals());
// Coloring
while (!alloc_queue.empty()) {
T_live_interval* interval = alloc_queue.top();
allocate(interval);
alloc_queue.pop();
}
rewrite();
DEBUG(verify());
for (int i = 0; i < num_of_lives; ++i) {
free(live_intervals[i]);
}
}
}
bool memory_coloring_impl::allocate(T_live_interval* interval)
{
shape s = interval->result;
std::size_t size = s.bytes();
if (size == 0)
return false;
std::size_t element_size = size / s.elements();
T_live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<T_live_range*, std::vector<T_live_range*>, ordering> conflict_queue;
std::unordered_map<long long, T_live_range*> offset2Live;
offset2Live.clear();
if (conflict_table.find(vn) != conflict_table.end()) {
std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) {
T_live_range* range = live_ranges[*iter];
long long offset = range->offset;
if (offset != InvalidOffset) {
conflict_queue.push(range);
if (offset2Live.find(offset) == offset2Live.end()) {
offset2Live[offset] = range;
} else {
T_live_range* prev = offset2Live[offset];
assert(prev->offset == offset);
if (prev->size < range->size)
offset2Live[offset] = range;
}
}
}
}
long long offset = 0;
while (!conflict_queue.empty()) {
T_live_range* range = conflict_queue.top();
long long cur_offset = range->offset;
if (offset2Live[cur_offset] == range) {
if ((cur_offset > offset) && (cur_offset - offset) >= size) {
break;
}
offset = cur_offset + range->size;
if ((offset % element_size) != 0)
offset += (element_size - (offset % element_size));
}
conflict_queue.pop();
}
segment.offset = offset;
DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
void memory_coloring_impl::build()
{
int num_of_instrs = p_program->get_size();
if (num_of_instrs == 0)
return;
int cur_points = num_of_instrs * 2;
instruction_ref iter = std::prev(p_program->end());
instruction_ref begin = p_program->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
do {
const instruction* p_iter = &(*iter);
T_live_interval* def_interval = nullptr;
bool isDead = false;
if (instr2Live.find(p_iter) != instr2Live.end()) {
def_interval = instr2Live[p_iter];
bool isLit = isLiteral(iter);
if (isAllocate(iter) || isLit) {
T_live_range& range = def_interval->segment;
def_interval->result = iter->result;
def_interval->isLiteral = isLit;
alloc_queue.push(def_interval);
range.begin = cur_points;
range.size = (iter->result).bytes();
live_set.erase(range.vn);
}
} else if (!isParam(iter) && !isOutline(iter) && !isCheckContext(iter)) {
isDead = true;
}
int tieNdx = getInputTieNdx(iter);
if (!iter->arguments.empty()) {
int cnt = -1;
for (auto&& arg : iter->arguments) {
cnt++;
if (isParam(arg) || isOutline(arg)) {
if (isOutputParam(arg))
isDead = false;
continue;
}
const instruction* p_arg = &(*arg);
if (cnt == tieNdx) {
// input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals.
def_interval->addUse(cur_points);
instr2Live[p_arg] = def_interval;
} else if (instr2Live.find(p_arg) == instr2Live.end()) {
// First time see a use, create a live interval.
int id = num_of_lives++;
T_live_interval* interval = new live_interval();
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->addUse(cur_points);
instr2Live[p_arg] = interval;
addConflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_intervals[id] = interval;
live_ranges[max_value_number] = &(interval->segment);
} else {
T_live_interval* interval = instr2Live[p_arg];
interval->addUse(cur_points);
DEBUG(assert(live_set.find(interval->id) != live_set.end()));
}
}
}
if (isDead)
dead_instrs.push_back(iter);
cur_points -= 2;
iter = std::prev(iter);
} while (iter != begin);
}
void memory_coloring_impl::rewrite()
{
instruction_ref end = p_program->end();
instruction_ref scratch_param = end;
std::vector<std::size_t> dims;
dims.push_back(required_bytes/sizeof(float));
shape s = {shape::float_type, dims};
scratch_param = p_program->add_parameter("scratch", s);
for (auto ins : iterator_for(*p_program)) {
const instruction* p_iter = &(*ins);
if (instr2Live.find(p_iter) != instr2Live.end()) {
T_live_interval* interval = instr2Live[p_iter];
if (interval->get_offset() == InvalidOffset) {
DEBUG(assert((interval->get_begin() == InvalidOffset) ||
interval->result.bytes() == 0));
continue;
}
std::size_t offset = interval->get_offset();
if (isAllocate(ins)) {
p_program->replace_instruction(ins, get_mem_ptr{offset}, scratch_param, ins->arguments.at(0));
} else if (isLiteral(ins)) {
auto pre = p_program->add_literal(ins->lit);
auto index = p_program->add_literal(offset);
p_program->replace_instruction(ins, write_literal{}, scratch_param, index, pre);
}
}
}
DEBUG(dump("---After rewrite---"));
DEBUG(dump_program());
}
#ifdef DEBUG_OPT
void memory_coloring_impl::dump(std::string str)
{
std::cout << str << std::endl;
}
void memory_coloring_impl::dump_program()
{
std::cout << *p_program << std::endl;
}
void memory_coloring_impl::dump_intervals()
{
if (num_of_lives > 0) {
std::cout << "---live intervals ---" << std::endl;
for (int i = 0; i < num_of_lives; ++i) {
T_live_interval* interval = live_intervals[i];
interval->dump();
}
std::cout << "---conflict table---" << std::endl;
for (int i = 0; i <= max_value_number; ++i) {
std::cout << " segment:" << i;
std::cout << " =>";
std::set<int>& table = conflict_table[i];
for (auto iter = table.begin(), end = table.end(); iter != end; ++iter) {
std::cout << (*iter) << ",";
}
}
std::cout << std::endl;
}
}
void memory_coloring_impl::verify()
{
if (num_of_lives > 0) {
for (int i = 0; i < num_of_lives; ++i) {
T_live_interval* interval = live_intervals[i];
T_live_range& segment = interval->segment;
if (segment.offset == InvalidOffset)
continue;
int vn = segment.vn;
if (conflict_table.find(vn) != conflict_table.end()) {
std::set<int>& vn_set = conflict_table[vn];
for (auto iter = vn_set.begin(), end = vn_set.end(); iter != end; ++iter) {
T_live_range* range = live_ranges[*iter];
if (range->offset == InvalidOffset)
continue;
if (!isDisjoin(*range, segment))
assert(false);
}
}
}
}
}
#define GET_INS_ENUM(x) (((x) >> 1) - 1)
void live_range::dump()
{
std::cout << " segment:" << vn;
std::cout << " [" << GET_INS_ENUM(begin) << ", " << GET_INS_ENUM(end) << "]";
if (offset != InvalidOffset) {
std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size - 1 << "]";
}
std::cout << std::endl;
}
void live_interval::dump()
{
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
for (auto iter = use_points.begin(), end = use_points.end(); iter != end; ++iter) {
int& use = *iter;
std::cout << " " << GET_INS_ENUM(use) << ",";
}
if (isLiteral)
std::cout << " literal";
std::cout << " " << result;
std::cout << std::endl;
}
#endif
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include "common_header.hpp"
namespace migraph {
#define InvalidOffset -1
typedef struct live_range {
int begin; // begin point in the instruction stream.
int end; // end point in the instruction stream.
long long offset; // offset to base pointer of allocated memory trunk.
int vn; // value number that identifies this live_range.
long long size; // size of required memory in bytes
#ifdef DEBUG_OPT
void dump();
#endif
} T_live_range;
typedef struct live_interval {
explicit live_interval() { init(); }
void init() {
id = -1; isLiteral = false;
segment = { -1, -1, InvalidOffset, -1, 0};
}
void addUse(int use) { use_points.push_front(use); }
int get_begin() const { return segment.begin; }
int get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef DEBUG_OPT
void dump();
#endif
T_live_range segment;
int id;
std::list<int> use_points;
shape result;
bool isLiteral;
} T_live_interval;
struct memory_coloring_impl {
explicit memory_coloring_impl(program *p) : p_program(p)
{
instr2Live.clear();
live_intervals.clear();
live_ranges.clear();
conflict_table.clear();
num_of_lives = 0;
max_value_number = -1;
required_bytes = 0;
}
bool allocate(T_live_interval*);
void addConflicts(std::set<int>& live_set, int val)
{
for (auto iter = live_set.begin(), end = live_set.end(); iter != end; ++ iter) {
conflict_table[*iter].insert(val);
conflict_table[val].insert(*iter);
}
}
void build();
void run();
void rewrite();
private:
bool isParam(const instruction_ref ins) { return ins->op.name() == "@param"; }
bool isOutputParam(const instruction_ref ins)
{
return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "output";
}
bool isScratchParam(const instruction_ref ins)
{
return isParam(ins) && any_cast<builtin::param>(ins->op).parameter == "scratch";
}
bool isAllocate(const instruction_ref ins) { return ins->op.name() == "hip::allocate"; }
bool isOutline(const instruction_ref ins) { return ins->op.name() == "@outline"; }
bool isLiteral(const instruction_ref ins) { return ins->op.name() == "@literal"; }
bool isCheckContext(const instruction_ref ins) { return ins->op.name() == "check_context"; }
bool isTranspose(const instruction_ref ins) { return ins->op.name() == "transpose"; }
int getInputTieNdx(const instruction_ref ins) {
if (isTranspose(ins))
return 0;
int cnt = -1;
int last_allocate = -1;
for (auto&& arg : ins->arguments) {
cnt++;
if (isAllocate(arg))
last_allocate = cnt;
}
return last_allocate;
}
bool isDisjoin(T_live_range& range1, T_live_range& range2) {
long long end1 = range1.offset + range1.size - 1;
long long end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) || (end2 < range1.offset));
}
#ifdef DEBUG_OPT
void dump(std::string);
void dump_program();
void dump_intervals();
void verify();
#endif
struct ordering {
bool operator() (const T_live_interval* I1, const T_live_interval* I2) const
{
int len1 = I1->get_end() - I1->get_begin();
int len2 = I2->get_end() - I2->get_begin();
if (len1 != len2) {
return (len1 < len2) ? true : false;
} else if (I1->result.bytes() != I2->result.bytes()) {
return (I1->result.bytes() < I2->result.bytes()) ? true : false;
} else {
return I1->id > I2->id;
}
}
bool operator() (const T_live_range* I1, const T_live_range* I2) const
{
return (I1->offset > I2->offset);
}
};
program* p_program;
std::unordered_map<const instruction*, T_live_interval*> instr2Live;
// Map live interval Id to live interval.
std::unordered_map<int, T_live_interval*> live_intervals;
// Map live range value number to live range.
std::unordered_map<int, T_live_range*> live_ranges;
// Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table;
// Priority queue for coloring.
std::priority_queue<T_live_interval*, std::vector<T_live_interval*>, ordering> alloc_queue;
int num_of_lives;
int max_value_number;
long long required_bytes;
};
} // namespace migraph
#endif
......@@ -320,6 +320,10 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
return generic_eval(
*this, this->impl->ctx, std::move(params), [](auto&, auto f) { return f(); });
}
int program::get_size() const
{
return (*impl).instructions.size();
}
double common_average(const std::vector<double>& v)
{
......
......@@ -13,3 +13,8 @@ target_link_libraries(migraph_cpu migraph Threads::Threads)
target_include_directories(migraph_cpu PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraph_cpu PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraph_cpu PRIVATE -DBLAZE_USE_CPP_THREADS)
install (TARGETS migraph_cpu
LIBRARY DESTINATION /opt/rocm/lib)
install (DIRECTORY include/migraph DESTINATION /opt/rocm/include)
......@@ -29,3 +29,11 @@ add_library(migraph_gpu
rocm_clang_tidy_check(migraph_gpu)
target_link_libraries(migraph_gpu migraph MIOpen migraph_device roc::rocblas)
target_include_directories(migraph_gpu PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
install (TARGETS migraph_gpu
LIBRARY DESTINATION /opt/rocm/lib)
install (DIRECTORY include/migraph DESTINATION /opt/rocm/include)
install (TARGETS migraph_device
LIBRARY DESTINATION /opt/rocm/lib)
install (DIRECTORY include/migraph DESTINATION /opt/rocm/include)
......@@ -90,6 +90,11 @@ argument from_gpu(argument arg)
void gpu_sync() { hipDeviceSynchronize(); }
void copy_to_gpu(char* dst, const char* src, std::size_t size)
{
hipMemcpy(dst, src, size, hipMemcpyHostToDevice);
}
} // namespace gpu
} // namespace migraph
......@@ -14,6 +14,7 @@ migraph::argument to_gpu(migraph::argument arg, bool host = false);
migraph::argument from_gpu(migraph::argument arg);
void gpu_sync();
void copy_to_gpu(char* dst, const char* src, std::size_t size);
struct hip_allocate
{
......@@ -81,6 +82,23 @@ struct hip_write
}
};
struct hip_memcpy
{
std::string name() const { return "hip_memcpy"; }
shape compute_shape(std::vector<shape> inputs) const
{
return inputs.at(2);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const {
std::size_t * p_data = reinterpret_cast<std::size_t*>(args.at(1).data());
char* dst = args.at(0).data() + p_data[0];
const char* src = args.at(2).data();
std::size_t size = args.at(2).get_shape().bytes();
copy_to_gpu(dst, src, size);
return {output_shape, dst};
}
};
} // namespace gpu
} // namespace migraph
......
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/lowering.hpp>
#include <migraph/memory_coloring.hpp>
#include <migraph/gpu/write_literals.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/eliminate_workspace.hpp>
......@@ -24,11 +25,13 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
eliminate_workspace{},
memory_coloring{},
// eliminate_workspace{},
eliminate_contiguous{},
dead_code_elimination{},
write_literals{&ctx},
eliminate_allocation{},
// eliminate_allocation{},
check_context<context>{},
dead_code_elimination{}
};
......
......@@ -28,6 +28,7 @@ void write_literals::apply(program& p) const
assert(ctx != nullptr);
for(auto ins : iterator_for(p))
{
#if 0
if(ins->op.name() == "@literal")
{
argument a = to_gpu(ins->lit.get_argument());
......@@ -35,6 +36,11 @@ void write_literals::apply(program& p) const
ctx->literals.push_back(a);
p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
}
#else
if (ins->op.name() == "write_literal") {
p.replace_instruction(ins, hip_memcpy{}, ins->arguments);
}
#endif
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment