// eliminate_workspace.cpp
#include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_config.hpp>
namespace migraphx {
Paul's avatar
Paul committed
11
inline namespace MIGRAPHX_INLINE_NS {
12
13
14
15
16
17
18
19
namespace gpu {

void eliminate_workspace::apply(program& p) const
{
    std::size_t n = 0;
    std::vector<instruction_ref> allocs;
    for(auto ins : iterator_for(p))
    {
Paul's avatar
Paul committed
20
        if(ins->outputs().size() != 1)
21
            continue;
Paul's avatar
Paul committed
22
        if(ins->name() != "hip::allocate")
23
            continue;
24
        auto&& a = any_cast<hip_allocate>(ins->get_operator());
Paul's avatar
Paul committed
25
        if(a.tag == "workspace")
26
27
28
29
30
        {
            n = std::max(n, ins->get_shape().bytes());
            allocs.push_back(ins);
        }
    }
Paul's avatar
Paul committed
31
    if(n > 0)
32
    {
Paul's avatar
Paul committed
33
34
35
36
37
38
        auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}});
        for(auto&& a : allocs)
        {
            p.replace_instruction(a, ws);
            p.remove_instruction(a);
        }
39
40
    }
}
41

42
} // namespace gpu
Paul's avatar
Paul committed
43
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
44
} // namespace migraphx