Commit 029b1e64 authored by mei-ye's avatar mei-ye
Browse files

merge to master

parent 733591e1
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_MEMORY_COLORING_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_MEMORY_COLORING_HPP
#include <migraph/program.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph {
namespace gpu {
struct lowering_memory_coloring
{
context* ctx = nullptr;
std::string name() const { return "gpu::lowering_memory_coloring"; }
void apply(program& p) const;
};
} // namespace gpu
} // namespace migraph
#endif
...@@ -10,7 +10,7 @@ struct target ...@@ -10,7 +10,7 @@ struct target
{ {
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraph::context& gctx) const; std::vector<pass> get_passes(migraph::context& gctx) const;
migraph::context get_context() const; migraph::context get_context(parameter_map params = parameter_map()) const;
}; };
} // namespace gpu } // namespace gpu
......
#include <migraph/gpu/lowering_memory_coloring.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/instruction.hpp>
#include <migraph/pass_config.hpp>
namespace migraph {
namespace gpu {
struct gen_base_addr
{
shape s;
std::string name() const { return "gen_base_addr"; }
shape compute_shape(const std::vector<shape>&) const
{
return s;
}
argument compute(const context& ctx, const shape&, const std::vector<argument>&) const
{
return ctx.scratch;
}
};
void lowering_memory_coloring::apply(program& p) const
{
if (enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
return;
assert(ctx != nullptr);
auto scratch_ins = p.get_parameter("scratch");
if (scratch_ins == p.end())
return;
bool can_resolve_addr = false;
argument base_ptr;
shape s_scratch = scratch_ins->result;
if (ctx->params.find("scratch") == ctx->params.end()) {
// scratch memory is not passed in, allocate memory here.
can_resolve_addr = true;
base_ptr = allocate_gpu(s_scratch, false);
} else {
argument a = ctx->params["scratch"];
assert((a.get_shape().bytes() >= s_scratch.bytes()) && "insufficent scratch memory");
if (!a.empty()) {
// scratch memory is passed in and already has a known address.
can_resolve_addr = true;
base_ptr = a;
}
}
if (can_resolve_addr) {
ctx->scratch = base_ptr;
scratch_ins = p.replace_instruction(scratch_ins, gen_base_addr{s_scratch});
}
for(auto ins : iterator_for(p))
{
if(ins->op.name() == "write_literal")
{
std::vector<instruction_ref>& args = ins->arguments;
instruction_ref arg0 = args.at(0);
instruction_ref arg1 = args.at(1);
shape s_arg1 = arg1->get_shape();
std::size_t size = s_arg1.bytes();
auto&& a = any_cast<write_literal>(ins->op);
std::size_t offset = a.offset;
if (can_resolve_addr && a.pre_copy) {
char* dst = base_ptr.data() + offset;
const char* src = arg1->lit.data();
copy_to_gpu(dst, src, size);
gpu_sync();
p.replace_instruction(ins, load{s_arg1, offset}, scratch_ins);
p.remove_instruction(arg1);
} else {
p.replace_instruction(ins, hip_memcpy{offset}, arg0, arg1);
}
}
}
// std::cout << p << std::endl;
}
} // namespace gpu
} // namespace migraph
#include <migraph/env.hpp>
namepsace migraph {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_DISABLE_MEMORY_COLORING)
} //namspace migraph
#include <migraph/gpu/target.hpp> #include <migraph/gpu/target.hpp>
#include <migraph/gpu/lowering.hpp> #include <migraph/gpu/lowering.hpp>
#include <migraph/memory_coloring.hpp>
#include <migraph/gpu/lowering_memory_coloring.hpp>
#include <migraph/gpu/write_literals.hpp> #include <migraph/gpu/write_literals.hpp>
#include <migraph/gpu/context.hpp> #include <migraph/gpu/context.hpp>
#include <migraph/gpu/eliminate_workspace.hpp> #include <migraph/gpu/eliminate_workspace.hpp>
...@@ -14,7 +16,7 @@ ...@@ -14,7 +16,7 @@
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
std::vector<pass> target::get_passes(migraph::context& gctx) const std::vector<pass> target::get_passes(migraph::context& gctx) const
{ {
auto& ctx = any_cast<context>(gctx); auto& ctx = any_cast<context>(gctx);
...@@ -28,6 +30,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -28,6 +30,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{}, simplify_reshapes{},
dead_code_elimination{}, dead_code_elimination{},
lowering{ctx}, lowering{ctx},
memory_coloring{},
lowering_memory_coloring{&ctx},
fuse_ops{}, fuse_ops{},
dead_code_elimination{}, dead_code_elimination{},
eliminate_workspace{}, eliminate_workspace{},
...@@ -43,10 +47,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -43,10 +47,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
std::string target::name() const { return "miopen"; } std::string target::name() const { return "miopen"; }
migraph::context target::get_context() const migraph::context target::get_context(parameter_map params) const
{ {
return context{share(make_obj<miopen_handle>(&miopenCreate)), return context{share(make_obj<miopen_handle>(&miopenCreate)),
share(create_rocblas_handle_ptr())}; share(create_rocblas_handle_ptr()), params, {}};
} }
} // namespace gpu } // namespace gpu
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <migraph/pass_config.hpp>
namespace migraph { namespace migraph {
...@@ -25,6 +26,9 @@ struct hip_load_literal ...@@ -25,6 +26,9 @@ struct hip_load_literal
void write_literals::apply(program& p) const void write_literals::apply(program& p) const
{ {
if (!enabled(MIGRAPH_DISABLE_MEMORY_COLORING{}))
return;
assert(ctx != nullptr); assert(ctx != nullptr);
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
......
...@@ -10,7 +10,7 @@ struct id_target ...@@ -10,7 +10,7 @@ struct id_target
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; } std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
migraph::context get_context() const { return {}; } migraph::context get_context(migraph::parameter_map) const { return {}; }
}; };
struct reverse_pass struct reverse_pass
...@@ -37,7 +37,7 @@ struct reverse_target ...@@ -37,7 +37,7 @@ struct reverse_target
{ {
std::string name() const { return "reverse"; } std::string name() const { return "reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; } std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; }
migraph::context get_context() const { return {}; } migraph::context get_context(migraph::parameter_map) const { return {}; }
}; };
struct double_reverse_target struct double_reverse_target
...@@ -47,7 +47,7 @@ struct double_reverse_target ...@@ -47,7 +47,7 @@ struct double_reverse_target
{ {
return {reverse_pass{}, reverse_pass{}}; return {reverse_pass{}, reverse_pass{}};
} }
migraph::context get_context() const { return {}; } migraph::context get_context(migraph::parameter_map) const { return {}; }
}; };
void literal_test1() void literal_test1()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment