// write_literals.cpp
#include <migraphx/gpu/write_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/env.hpp>

#include <cassert>
#include <cstddef>
#include <iterator>
#include <string>

namespace migraphx {
Paul's avatar
Paul committed
9
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
10
namespace gpu {
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
13

14
void write_literals::apply(module& m) const
Paul's avatar
Paul committed
15
{
Paul's avatar
Paul committed
16
    assert(ctx != nullptr);
17
    std::size_t n = 0;
18
    for(auto ins : iterator_for(m))
Paul's avatar
Paul committed
19
    {
Paul's avatar
Paul committed
20
        if(ins->name() == "@literal")
Paul's avatar
Paul committed
21
        {
Paul's avatar
Paul committed
22
            if(enabled(MIGRAPHX_COPY_LITERALS{}))
23
            {
Paul's avatar
Paul committed
24
                literal l  = ins->get_literal();
25
26
27
                auto pre   = m.add_literal(l);
                auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
                m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
Paul's avatar
Paul committed
28
            }
29
30
            else
            {
31
32
                std::string id = m.name() + ":@literal:" + std::to_string(n);
                m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
33
                n++;
34
            }
Paul's avatar
Paul committed
35
36
37
        }
    }
}
38

Paul's avatar
Paul committed
39
} // namespace gpu
Paul's avatar
Paul committed
40
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
41
} // namespace migraphx