write_literals.cpp 1.12 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/gpu/write_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
6

Paul's avatar
Paul committed
7
namespace migraphx {
Paul's avatar
Paul committed
8
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
9
namespace gpu {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
12

Paul's avatar
Paul committed
13
void write_literals::apply(program& p) const
Paul's avatar
Paul committed
14
{
Paul's avatar
Paul committed
15
    assert(ctx != nullptr);
16
    std::size_t n = 0;
Paul's avatar
Paul committed
17
18
    for(auto ins : iterator_for(p))
    {
Paul's avatar
Paul committed
19
        if(ins->name() == "@literal")
Paul's avatar
Paul committed
20
        {
Paul's avatar
Paul committed
21
            if(enabled(MIGRAPHX_COPY_LITERALS{}))
22
            {
Paul's avatar
Paul committed
23
24
                literal l  = ins->get_literal();
                auto pre   = p.add_literal(l);
25
                auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
26
                p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
Paul's avatar
Paul committed
27
            }
28
29
            else
            {
30
                std::string id = "@literal:" + std::to_string(n);
31
                p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
32
                n++;
33
            }
Paul's avatar
Paul committed
34
35
36
        }
    }
}
37

Paul's avatar
Paul committed
38
} // namespace gpu
Paul's avatar
Paul committed
39
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
40
} // namespace migraphx