Commit 211bfd05 authored by Paul's avatar Paul
Browse files

Copy literals during compile

parent a398a5c1
......@@ -12,6 +12,7 @@ struct context
{
shared<miopen_handle> handle;
shared<rocblas_handle_ptr> rbhandle;
std::vector<argument> literals{};
void finish() const { gpu_sync(); }
};
......
......@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#include <migraph/program.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph {
......@@ -9,6 +10,7 @@ namespace gpu {
struct write_literals
{
context * ctx = nullptr;
std::string name() const { return "gpu::write_literals"; }
void apply(program& p) const;
......
......@@ -27,7 +27,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
eliminate_workspace{},
eliminate_contiguous{},
dead_code_elimination{},
write_literals{},
write_literals{&ctx},
eliminate_allocation{},
check_context<context>{},
dead_code_elimination{}
......
......@@ -7,15 +7,33 @@ namespace migraph {
namespace gpu {
struct hip_load_literal
{
shape s;
std::size_t n = 0;
std::string name() const { return "hip::load_literal"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(0);
return s;
}
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
return ctx.literals.at(n);
}
};
void write_literals::apply(program& p) const
{
assert(ctx != nullptr);
for(auto ins : iterator_for(p))
{
if(ins->op.name() == "@literal")
{
literal l = ins->lit;
auto pre = p.add_literal(l);
p.replace_instruction(ins, hip_write{}, pre);
argument a = to_gpu(ins->lit.get_argument());
std::size_t n = ctx->literals.size();
ctx->literals.push_back(a);
p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment