Commit 211bfd05 authored by Paul's avatar Paul
Browse files

Copy literals during compile

parent a398a5c1
...@@ -12,6 +12,7 @@ struct context ...@@ -12,6 +12,7 @@ struct context
{ {
shared<miopen_handle> handle; shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle; shared<rocblas_handle_ptr> rbhandle;
std::vector<argument> literals{};
void finish() const { gpu_sync(); } void finish() const { gpu_sync(); }
}; };
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP #define MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph { namespace migraph {
...@@ -9,6 +10,7 @@ namespace gpu { ...@@ -9,6 +10,7 @@ namespace gpu {
struct write_literals struct write_literals
{ {
context * ctx = nullptr;
std::string name() const { return "gpu::write_literals"; } std::string name() const { return "gpu::write_literals"; }
void apply(program& p) const; void apply(program& p) const;
......
...@@ -27,7 +27,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const ...@@ -27,7 +27,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
eliminate_workspace{}, eliminate_workspace{},
eliminate_contiguous{}, eliminate_contiguous{},
dead_code_elimination{}, dead_code_elimination{},
write_literals{}, write_literals{&ctx},
eliminate_allocation{}, eliminate_allocation{},
check_context<context>{}, check_context<context>{},
dead_code_elimination{} dead_code_elimination{}
......
...@@ -7,15 +7,33 @@ namespace migraph { ...@@ -7,15 +7,33 @@ namespace migraph {
namespace gpu { namespace gpu {
struct hip_load_literal
{
shape s;
std::size_t n = 0;
std::string name() const { return "hip::load_literal"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(0);
return s;
}
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
return ctx.literals.at(n);
}
};
void write_literals::apply(program& p) const void write_literals::apply(program& p) const
{ {
assert(ctx != nullptr);
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(ins->op.name() == "@literal") if(ins->op.name() == "@literal")
{ {
literal l = ins->lit; argument a = to_gpu(ins->lit.get_argument());
auto pre = p.add_literal(l); std::size_t n = ctx->literals.size();
p.replace_instruction(ins, hip_write{}, pre); ctx->literals.push_back(a);
p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment