write_literals.cpp 1.7 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/gpu/write_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
9
namespace gpu {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
// Environment flag consulted by write_literals::apply. When enabled, each literal
// stays in the program as a literal and is moved to device memory by a
// hip_copy_to_gpu instruction (into a hip_allocate buffer). When disabled, the
// literal is converted with to_gpu during the pass and stored in ctx->literals.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
12

Paul's avatar
Paul committed
13
14
15
16
// Operator that yields a literal which was already converted for the GPU.
//
// It takes no inputs. The produced value is the entry at index `n` of
// `ctx.literals` (write_literals::apply appends that entry and records its
// index here). `s` is the shape reported for the result.
struct hip_load_literal
{
    shape s;          // shape reported by compute_shape
    std::size_t n{0}; // index of the value inside ctx.literals

    // Exposes both fields to the reflection machinery under the keys
    // "shape" and "id".
    template <class T, class Visitor>
    static auto reflect(T& self, Visitor v)
    {
        return pack(v(self.s, "shape"), v(self.n, "id"));
    }

    std::string name() const { return "hip::load_literal"; }

    // Accepts exactly zero inputs; the output shape is the stored one.
    shape compute_shape(const std::vector<shape>& args) const
    {
        check_shapes{args}.has(0);
        return s;
    }

    // Returns a copy of the stored entry. Lookup is bounds-checked through `at`.
    argument compute(context& ctx, const shape&, const std::vector<argument>&) const
    {
        const auto& stored = ctx.literals.at(n);
        return stored;
    }
};

Paul's avatar
Paul committed
36
// Rewrites every "@literal" instruction of `p` so that its data ends up on the GPU.
//
// The strategy is selected by the MIGRAPHX_COPY_LITERALS environment flag:
//  * disabled: the literal's argument is converted with to_gpu right now and
//    appended to ctx->literals. The instruction is replaced by a
//    hip_load_literal that refers to that entry by index.
//  * enabled: a copy of the literal is added back to the program with
//    add_literal, and a hip_allocate of the same shape is inserted directly
//    after it. The original instruction becomes a hip_copy_to_gpu from the
//    added literal into that allocation.
//
// `ctx` must be non-null. This is only checked by assert, so it is not
// checked in NDEBUG builds.
void write_literals::apply(program& p) const
{
    assert(ctx != nullptr);

    for(auto ins : iterator_for(p))
    {
        // Only literal instructions are rewritten.
        if(ins->name() != "@literal")
            continue;

        if(!enabled(MIGRAPHX_COPY_LITERALS{}))
        {
            // Convert once, remember where it lives, and load it by index later.
            argument gpu_arg       = to_gpu(ins->get_literal().get_argument());
            std::size_t literal_id = ctx->literals.size();
            ctx->literals.push_back(gpu_arg);
            p.replace_instruction(ins, hip_load_literal{gpu_arg.get_shape(), literal_id});
            continue;
        }

        // Keep the data as a program literal and emit an explicit device copy.
        auto host_lit = ins->get_literal();
        auto host_ins = p.add_literal(host_lit);
        auto buffer   = p.insert_instruction(std::next(host_ins), hip_allocate{host_lit.get_shape()});
        p.replace_instruction(ins, hip_copy_to_gpu{}, host_ins, buffer);
    }
}
60

Paul's avatar
Paul committed
61
} // namespace gpu
Paul's avatar
Paul committed
62
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
63
} // namespace migraphx