fuse_ops.cpp 1.08 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include <migraph/gpu/fuse_ops.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/device/add_relu.hpp>
#include <migraph/instruction.hpp>

namespace migraph {

namespace gpu {

struct hip_add_relu
{
    std::string name() const { return "hip::add_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
Paul's avatar
Paul committed
15
        check_shapes{inputs, *this}.has(3);
Paul's avatar
Paul committed
16
17
        return inputs.front();
    }
Paul's avatar
Paul committed
18
    argument compute(context&, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
19
    {
20
        device::add_relu(args.at(2), args.at(0), args.at(1));
Paul's avatar
Paul committed
21
22
23
24
25
26
27
28
        return args.at(2);
    }
};

/// Walks every instruction of the program and, wherever a "gpu::relu" takes
/// a "gpu::add" as its first input, replaces the relu instruction with a
/// single hip_add_relu instruction.
void fuse_ops::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->name() == "gpu::relu")
        {
            // Only the first input of the relu is inspected as the producer.
            auto producer = ins->inputs().front();
            if(producer->name() == "gpu::add")
            {
                // Start from the add's own inputs, then swap its last input
                // for the relu's last input so the fused op uses the
                // allocation from the relu operator.
                auto fused_args   = producer->inputs();
                fused_args.back() = ins->inputs().back();
                p.replace_instruction(ins, hip_add_relu{}, fused_args);
            }
        }
    }
}

} // namespace gpu

} // namespace migraph