Commit 39eedcd9 authored by mei-ye's avatar mei-ye
Browse files

remove passing in address

parent b1e097b3
...@@ -28,7 +28,7 @@ struct program ...@@ -28,7 +28,7 @@ struct program
program& operator=(program&&) noexcept; program& operator=(program&&) noexcept;
~program() noexcept; ~program() noexcept;
using parameter_map = migraph::parameter_map; using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts> template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
...@@ -91,7 +91,7 @@ struct program ...@@ -91,7 +91,7 @@ struct program
instruction_ref validate() const; instruction_ref validate() const;
void compile(const target& t, tracer trace = tracer{}, parameter_map params = parameter_map()); void compile(const target& t, tracer trace = tracer{});
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const; void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
......
...@@ -8,13 +8,10 @@ ...@@ -8,13 +8,10 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <unordered_map>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/pass.hpp> #include <migraph/pass.hpp>
#include <migraph/argument.hpp>
namespace migraph { namespace migraph {
using parameter_map = std::unordered_map<std::string, argument>;
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -36,7 +33,7 @@ struct target ...@@ -36,7 +33,7 @@ struct target
* @brief Construct a context for the target. * @brief Construct a context for the target.
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context(parameter_map params = parameter_map()) const; context get_context() const;
}; };
#else #else
...@@ -122,10 +119,10 @@ struct target ...@@ -122,10 +119,10 @@ struct target
return (*this).private_detail_te_get_handle().get_passes(ctx); return (*this).private_detail_te_get_handle().get_passes(ctx);
} }
context get_context(parameter_map params = parameter_map()) const context get_context() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context(std::move(params)); return (*this).private_detail_te_get_handle().get_context();
} }
private: private:
...@@ -137,7 +134,7 @@ struct target ...@@ -137,7 +134,7 @@ struct target
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0; virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context(parameter_map params) const = 0; virtual context get_context() const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -176,10 +173,7 @@ struct target ...@@ -176,10 +173,7 @@ struct target
return private_detail_te_value.get_passes(ctx); return private_detail_te_value.get_passes(ctx);
} }
context get_context(parameter_map params) const override context get_context() const override { return private_detail_te_value.get_context(); }
{
return private_detail_te_value.get_context(params);
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -257,10 +257,10 @@ instruction_ref program::validate() const ...@@ -257,10 +257,10 @@ instruction_ref program::validate() const
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); }); [&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
} }
void program::compile(const target& t, tracer trace, parameter_map params) void program::compile(const target& t, tracer trace)
{ {
assert(this->validate() == impl->instructions.end()); assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context(std::move(params)); this->impl->ctx = t.get_context();
if(not trace.enabled() and enabled(MIGRAPH_TRACE_COMPILE{})) if(not trace.enabled() and enabled(MIGRAPH_TRACE_COMPILE{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
trace(*this); trace(*this);
......
...@@ -6,11 +6,9 @@ ...@@ -6,11 +6,9 @@
namespace migraph { namespace migraph {
namespace cpu { namespace cpu {
using parameter_map = std::unordered_map<std::string, argument>;
struct context struct context
{ {
parameter_map params;
void finish() const {} void finish() const {}
}; };
} // namespace cpu } // namespace cpu
......
...@@ -11,10 +11,7 @@ struct cpu_target ...@@ -11,10 +11,7 @@ struct cpu_target
{ {
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraph::context& ctx) const; std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context(parameter_map params = parameter_map()) const migraph::context get_context() const { return context{}; }
{
return context{std::move(params)};
}
}; };
} // namespace cpu } // namespace cpu
} // namespace migraph } // namespace migraph
......
...@@ -5,16 +5,13 @@ ...@@ -5,16 +5,13 @@
#include <migraph/gpu/rocblas.hpp> #include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <unordered_map>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
using parameter_map = std::unordered_map<std::string, argument>;
struct context struct context
{ {
shared<miopen_handle> handle; shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle; shared<rocblas_handle_ptr> rbhandle;
parameter_map params;
argument scratch; argument scratch;
std::vector<argument> literals{}; std::vector<argument> literals{};
void finish() const { gpu_sync(); } void finish() const { gpu_sync(); }
......
...@@ -10,7 +10,7 @@ struct target ...@@ -10,7 +10,7 @@ struct target
{ {
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraph::context& gctx) const; std::vector<pass> get_passes(migraph::context& gctx) const;
migraph::context get_context(parameter_map params = parameter_map()) const; migraph::context get_context() const;
}; };
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
......
...@@ -29,32 +29,10 @@ void lowering_memory_coloring::apply(program& p) const ...@@ -29,32 +29,10 @@ void lowering_memory_coloring::apply(program& p) const
if(scratch_ins == p.end()) if(scratch_ins == p.end())
return; return;
bool can_resolve_addr = false;
argument base_ptr;
shape s_scratch = scratch_ins->result; shape s_scratch = scratch_ins->result;
argument base_ptr = allocate_gpu(s_scratch, false);
if(ctx->params.find("scratch") == ctx->params.end())
{
// scratch memory is not passed in, allocate memory here.
can_resolve_addr = true;
base_ptr = allocate_gpu(s_scratch, false);
}
else
{
argument a = ctx->params["scratch"];
assert((a.get_shape().bytes() >= s_scratch.bytes()) && "insufficent scratch memory");
if(!a.empty())
{
// scratch memory is passed in and already has a known address.
can_resolve_addr = true;
base_ptr = a;
}
}
if(can_resolve_addr)
{
ctx->scratch = base_ptr; ctx->scratch = base_ptr;
scratch_ins = p.replace_instruction(scratch_ins, gen_base_addr{s_scratch}); scratch_ins = p.replace_instruction(scratch_ins, gen_base_addr{s_scratch});
}
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
...@@ -69,7 +47,7 @@ void lowering_memory_coloring::apply(program& p) const ...@@ -69,7 +47,7 @@ void lowering_memory_coloring::apply(program& p) const
auto&& a = any_cast<write_literal>(ins->op); auto&& a = any_cast<write_literal>(ins->op);
std::size_t offset = a.offset; std::size_t offset = a.offset;
if(can_resolve_addr && a.pre_copy) if(a.pre_copy)
{ {
char* dst = base_ptr.data() + offset; char* dst = base_ptr.data() + offset;
const char* src = arg1->lit.data(); const char* src = arg1->lit.data();
......
...@@ -47,12 +47,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -47,12 +47,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
std::string target::name() const { return "miopen"; } std::string target::name() const { return "miopen"; }
migraph::context target::get_context(parameter_map params) const migraph::context target::get_context() const
{ {
return context{share(make_obj<miopen_handle>(&miopenCreate)), return context{
share(create_rocblas_handle_ptr()), share(make_obj<miopen_handle>(&miopenCreate)), share(create_rocblas_handle_ptr()), {}};
params,
{}};
} }
} // namespace gpu } // namespace gpu
} // namespace migraph } // namespace migraph
...@@ -10,7 +10,7 @@ struct contiguous_target ...@@ -10,7 +10,7 @@ struct contiguous_target
{ {
return {migraph::auto_contiguous{}}; return {migraph::auto_contiguous{}};
} }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
void literal_broadcast() void literal_broadcast()
......
...@@ -12,7 +12,7 @@ struct eliminate_allocation_target ...@@ -12,7 +12,7 @@ struct eliminate_allocation_target
{ {
return {migraph::eliminate_allocation{"allocate", align}, migraph::dead_code_elimination{}}; return {migraph::eliminate_allocation{"allocate", align}, migraph::dead_code_elimination{}};
} }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
struct allocate struct allocate
......
...@@ -10,7 +10,7 @@ struct id_target ...@@ -10,7 +10,7 @@ struct id_target
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; } std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
struct reverse_pass struct reverse_pass
...@@ -37,7 +37,7 @@ struct reverse_target ...@@ -37,7 +37,7 @@ struct reverse_target
{ {
std::string name() const { return "reverse"; } std::string name() const { return "reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; } std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
struct double_reverse_target struct double_reverse_target
...@@ -47,7 +47,7 @@ struct double_reverse_target ...@@ -47,7 +47,7 @@ struct double_reverse_target
{ {
return {reverse_pass{}, reverse_pass{}}; return {reverse_pass{}, reverse_pass{}};
} }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
void literal_test1() void literal_test1()
......
...@@ -10,7 +10,7 @@ struct memory_coloring_target ...@@ -10,7 +10,7 @@ struct memory_coloring_target
{ {
return {migraph::memory_coloring{}}; return {migraph::memory_coloring{}};
} }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
int main() int main()
......
...@@ -11,7 +11,7 @@ struct simplify_reshapes_target ...@@ -11,7 +11,7 @@ struct simplify_reshapes_target
{ {
return {migraph::simplify_reshapes{}, migraph::dead_code_elimination{}}; return {migraph::simplify_reshapes{}, migraph::dead_code_elimination{}};
} }
migraph::context get_context(migraph::parameter_map) const { return {}; } migraph::context get_context() const { return {}; }
}; };
void double_contig() void double_contig()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment