fuse_ops.cpp 1.19 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraph/gpu/fuse_ops.hpp>
Paul's avatar
Paul committed
2
#include <migraph/matcher.hpp>
Paul's avatar
Paul committed
3
4
5
6
7
8
9
10
11
12
13
14
#include <migraph/gpu/device/add_relu.hpp>
#include <migraph/instruction.hpp>

namespace migraph {

namespace gpu {

struct hip_add_relu
{
    std::string name() const { return "hip::add_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
Paul's avatar
Paul committed
15
        check_shapes{inputs, *this}.has(3);
Paul's avatar
Paul committed
16
17
        return inputs.front();
    }
Paul's avatar
Paul committed
18
    argument compute(context&, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
19
    {
20
        device::add_relu(args.at(2), args.at(0), args.at(1));
Paul's avatar
Paul committed
21
22
23
24
        return args.at(2);
    }
};

Paul's avatar
Paul committed
25
struct match_add_relu
Paul's avatar
Paul committed
26
{
Paul's avatar
Paul committed
27
28
29
30
31
32
    auto matcher() const { return match::name("gpu::relu")(match::args(match::name("gpu::add").bind("add"))); }

    void apply(program& p, match::matcher_result r) const 
    { 
        auto add_ins = r.instructions["add"];
        auto ins = r.result;
Paul's avatar
Paul committed
33
        auto args = add_ins->inputs();
Paul's avatar
Paul committed
34
        // Use the allocation from the relu operator
Paul's avatar
Paul committed
35
        args.back() = ins->inputs().back();
Paul's avatar
Paul committed
36
        p.replace_instruction(ins, hip_add_relu{}, args);
Paul's avatar
Paul committed
37
    }
Paul's avatar
Paul committed
38
39
40
41
42
};

/// Applies the GPU fusion rewrites to @p p. The only rewrite registered here
/// is match_add_relu (gpu::relu of gpu::add -> hip::add_relu).
void fuse_ops::apply(program& p) const
{
    const match_add_relu add_relu_fusion{};
    match::find_matches(p, add_relu_fusion);
}

} // namespace gpu

} // namespace migraph