write_literals.cpp 1.04 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraph/gpu/write_literals.hpp>
Paul's avatar
Paul committed
2
#include <migraph/iterator_for.hpp>
Paul's avatar
Paul committed
3
#include <migraph/gpu/hip.hpp>
Paul's avatar
Paul committed
4
#include <migraph/instruction.hpp>
mei-ye's avatar
mei-ye committed
5
#include <migraph/pass_config.hpp>
Paul's avatar
Paul committed
6
7
8

namespace migraph {

Paul's avatar
Paul committed
9
namespace gpu {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/// Placeholder operation for a literal whose data has already been registered
/// in the GPU context by `write_literals::apply`.
///
/// It takes no inputs. Its output shape is the stored `s`, and evaluating it
/// hands back the argument stored at position `n` of `ctx.literals`.
struct hip_load_literal
{
    shape s;           // shape reported by compute_shape (the uploaded literal's shape)
    std::size_t n = 0; // slot of the uploaded literal inside ctx.literals

    std::string name() const { return std::string{"hip::load_literal"}; }

    /// Requires an empty input list and reports the recorded shape.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(0);
        return s;
    }

    /// Fetches the pre-registered literal by slot index; the shape and input
    /// arguments are intentionally ignored.
    /// NOTE(review): bounds behavior depends on the container type of
    /// `ctx.literals` (`at` on a std::vector would throw std::out_of_range) — confirm.
    argument compute(context& ctx, const shape&, const std::vector<argument>&) const
    {
        const auto& stored = ctx.literals.at(n);
        return stored;
    }
};

Paul's avatar
Paul committed
27
/// Rewrites every `@literal` instruction in `p` into a `hip_load_literal`.
///
/// For each literal, the argument produced by `to_gpu` (presumably a
/// host-to-device copy — see hip.hpp) is appended to `ctx->literals`. The
/// instruction is then replaced by an operation that refers back to that slot.
/// Requires `ctx` to be non-null (asserted).
void write_literals::apply(program& p) const
{
    assert(ctx != nullptr);
    for(auto ins : iterator_for(p))
    {
        // Only literal instructions are rewritten; all others pass through untouched.
        if(ins->op.name() != "@literal")
            continue;

        argument gpu_arg = to_gpu(ins->lit.get_argument());

        // The slot must be read before appending so it indexes the new entry.
        const std::size_t slot = ctx->literals.size();
        ctx->literals.push_back(gpu_arg);

        p.replace_instruction(ins, hip_load_literal{gpu_arg.get_shape(), slot});
    }
}
Paul's avatar
Paul committed
41
} // namespace gpu
Paul's avatar
Paul committed
42
} // namespace migraph