Commit 39eedcd9 authored by mei-ye's avatar mei-ye
Browse files

remove passing in address

parent b1e097b3
......@@ -28,7 +28,7 @@ struct program
program& operator=(program&&) noexcept;
~program() noexcept;
using parameter_map = migraph::parameter_map;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
......@@ -91,7 +91,7 @@ struct program
instruction_ref validate() const;
void compile(const target& t, tracer trace = tracer{}, parameter_map params = parameter_map());
void compile(const target& t, tracer trace = tracer{});
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
......
......@@ -8,13 +8,10 @@
#include <type_traits>
#include <utility>
#include <vector>
#include <unordered_map>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
#include <migraph/argument.hpp>
namespace migraph {
using parameter_map = std::unordered_map<std::string, argument>;
#ifdef DOXYGEN
......@@ -36,7 +33,7 @@ struct target
* @brief Construct a context for the target.
* @return The context to be used during compilation and execution.
*/
context get_context(parameter_map params = parameter_map()) const;
context get_context() const;
};
#else
......@@ -122,10 +119,10 @@ struct target
return (*this).private_detail_te_get_handle().get_passes(ctx);
}
context get_context(parameter_map params = parameter_map()) const
context get_context() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_context(std::move(params));
return (*this).private_detail_te_get_handle().get_context();
}
private:
......@@ -137,7 +134,7 @@ struct target
virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context(parameter_map params) const = 0;
virtual context get_context() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -176,10 +173,7 @@ struct target
return private_detail_te_value.get_passes(ctx);
}
context get_context(parameter_map params) const override
{
return private_detail_te_value.get_context(params);
}
context get_context() const override { return private_detail_te_value.get_context(); }
PrivateDetailTypeErasedT private_detail_te_value;
};
......
......@@ -257,10 +257,10 @@ instruction_ref program::validate() const
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
}
void program::compile(const target& t, tracer trace, parameter_map params)
void program::compile(const target& t, tracer trace)
{
assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context(std::move(params));
this->impl->ctx = t.get_context();
if(not trace.enabled() and enabled(MIGRAPH_TRACE_COMPILE{}))
trace = tracer{std::cout};
trace(*this);
......
......@@ -6,11 +6,9 @@
namespace migraph {
namespace cpu {
using parameter_map = std::unordered_map<std::string, argument>;
struct context
{
parameter_map params;
void finish() const {}
};
} // namespace cpu
......
......@@ -11,10 +11,7 @@ struct cpu_target
{
std::string name() const;
std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context(parameter_map params = parameter_map()) const
{
return context{std::move(params)};
}
migraph::context get_context() const { return context{}; }
};
} // namespace cpu
} // namespace migraph
......
......@@ -5,16 +5,13 @@
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/hip.hpp>
#include <unordered_map>
namespace migraph {
namespace gpu {
using parameter_map = std::unordered_map<std::string, argument>;
struct context
{
shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
parameter_map params;
argument scratch;
std::vector<argument> literals{};
void finish() const { gpu_sync(); }
......
......@@ -10,7 +10,7 @@ struct target
{
std::string name() const;
std::vector<pass> get_passes(migraph::context& gctx) const;
migraph::context get_context(parameter_map params = parameter_map()) const;
migraph::context get_context() const;
};
} // namespace gpu
} // namespace migraph
......
......@@ -29,32 +29,10 @@ void lowering_memory_coloring::apply(program& p) const
if(scratch_ins == p.end())
return;
bool can_resolve_addr = false;
argument base_ptr;
shape s_scratch = scratch_ins->result;
if(ctx->params.find("scratch") == ctx->params.end())
{
// scratch memory is not passed in, allocate memory here.
can_resolve_addr = true;
base_ptr = allocate_gpu(s_scratch, false);
}
else
{
argument a = ctx->params["scratch"];
assert((a.get_shape().bytes() >= s_scratch.bytes()) && "insufficent scratch memory");
if(!a.empty())
{
// scratch memory is passed in and already has a known address.
can_resolve_addr = true;
base_ptr = a;
}
}
if(can_resolve_addr)
{
ctx->scratch = base_ptr;
scratch_ins = p.replace_instruction(scratch_ins, gen_base_addr{s_scratch});
}
shape s_scratch = scratch_ins->result;
argument base_ptr = allocate_gpu(s_scratch, false);
ctx->scratch = base_ptr;
scratch_ins = p.replace_instruction(scratch_ins, gen_base_addr{s_scratch});
for(auto ins : iterator_for(p))
{
......@@ -69,7 +47,7 @@ void lowering_memory_coloring::apply(program& p) const
auto&& a = any_cast<write_literal>(ins->op);
std::size_t offset = a.offset;
if(can_resolve_addr && a.pre_copy)
if(a.pre_copy)
{
char* dst = base_ptr.data() + offset;
const char* src = arg1->lit.data();
......
......@@ -47,12 +47,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
std::string target::name() const { return "miopen"; }
migraph::context target::get_context(parameter_map params) const
migraph::context target::get_context() const
{
return context{share(make_obj<miopen_handle>(&miopenCreate)),
share(create_rocblas_handle_ptr()),
params,
{}};
return context{
share(make_obj<miopen_handle>(&miopenCreate)), share(create_rocblas_handle_ptr()), {}};
}
} // namespace gpu
} // namespace migraph
......@@ -10,7 +10,7 @@ struct contiguous_target
{
return {migraph::auto_contiguous{}};
}
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
void literal_broadcast()
......
......@@ -12,7 +12,7 @@ struct eliminate_allocation_target
{
return {migraph::eliminate_allocation{"allocate", align}, migraph::dead_code_elimination{}};
}
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
struct allocate
......
......@@ -10,7 +10,7 @@ struct id_target
{
std::string name() const { return "id"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
struct reverse_pass
......@@ -37,7 +37,7 @@ struct reverse_target
{
std::string name() const { return "reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; }
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
struct double_reverse_target
......@@ -47,7 +47,7 @@ struct double_reverse_target
{
return {reverse_pass{}, reverse_pass{}};
}
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
void literal_test1()
......
......@@ -10,7 +10,7 @@ struct memory_coloring_target
{
return {migraph::memory_coloring{}};
}
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
int main()
......
......@@ -11,7 +11,7 @@ struct simplify_reshapes_target
{
return {migraph::simplify_reshapes{}, migraph::dead_code_elimination{}};
}
migraph::context get_context(migraph::parameter_map) const { return {}; }
migraph::context get_context() const { return {}; }
};
void double_contig()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment