write_literals.cpp 1.17 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraph/gpu/write_literals.hpp>
Paul's avatar
Paul committed
2
#include <migraph/iterator_for.hpp>
Paul's avatar
Paul committed
3
#include <migraph/gpu/hip.hpp>
Paul's avatar
Paul committed
4
5
6
7
#include <migraph/instruction.hpp>

namespace migraph {

Paul's avatar
Paul committed
8
namespace gpu {
Paul's avatar
Paul committed
9

Paul's avatar
Paul committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/// Operator that yields a literal previously stored in the context.
///
/// It takes no inputs. Its result is the argument held at index `n` in
/// `ctx.literals`.
struct hip_load_literal
{
    // Shape reported by compute_shape for this instruction's output.
    shape s;
    // Position of the stored literal inside ctx.literals.
    std::size_t n = 0;

    /// Name under which this operator is identified.
    std::string name() const
    {
        return "hip::load_literal";
    }

    /// Validates that no inputs are supplied and reports the stored shape.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes checker{inputs};
        checker.has(0);
        return s;
    }

    /// Fetches the literal at index `n` from the context's literal table.
    /// Access goes through .at(), so an out-of-range index is bounds-checked.
    argument compute(context& ctx, const shape&, const std::vector<argument>&) const
    {
        const auto& table = ctx.literals;
        return table.at(n);
    }
};

Paul's avatar
Paul committed
26
/// Rewrites literal-write instructions for the GPU target.
///
/// Every instruction whose operator name is "write_literal" is passed to
/// p.replace_instruction with a hip_memcpy{} that takes the same arguments.
/// All other instructions are skipped.
void write_literals::apply(program& p) const
{
    // The pass requires a context to be attached. Only the disabled path
    // below dereferences it.
    assert(ctx != nullptr);

    for(auto ins : iterator_for(p))
    {
#if 0
        // Disabled path: upload each @literal to the device, append it to
        // ctx->literals, and replace the instruction with a hip_load_literal
        // that returns the stored argument by index.
        if(ins->op.name() == "@literal")
        {
            argument a    = to_gpu(ins->lit.get_argument());
            std::size_t n = ctx->literals.size();
            ctx->literals.push_back(a);
            p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
        }
#else
        const auto op_name = ins->op.name();
        if(op_name != "write_literal")
            continue;
        p.replace_instruction(ins, hip_memcpy{}, ins->arguments);
#endif
    }
}

Paul's avatar
Paul committed
47
} // namespace gpu
Paul's avatar
Paul committed
48
49

} // namespace migraph