fuse_ops.cpp 1.21 KB
Newer Older
Paul's avatar
Paul committed
1
#include <migraph/gpu/fuse_ops.hpp>
Paul's avatar
Paul committed
2
#include <migraph/matcher.hpp>
Paul's avatar
Paul committed
3
4
5
6
7
8
9
10
11
12
13
14
#include <migraph/gpu/device/add_relu.hpp>
#include <migraph/instruction.hpp>

namespace migraph {

namespace gpu {

struct hip_add_relu
{
    std::string name() const { return "hip::add_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
Paul's avatar
Paul committed
15
        check_shapes{inputs, *this}.has(3);
Paul's avatar
Paul committed
16
17
        return inputs.front();
    }
Paul's avatar
Paul committed
18
    argument compute(context&, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
19
    {
20
        device::add_relu(args.at(2), args.at(0), args.at(1));
Paul's avatar
Paul committed
21
22
23
24
        return args.at(2);
    }
};

Paul's avatar
Paul committed
25
struct match_add_relu
Paul's avatar
Paul committed
26
{
Paul's avatar
Paul committed
27
28
29
30
    auto matcher() const
    {
        return match::name("gpu::relu")(match::args(match::name("gpu::add").bind("add")));
    }
Paul's avatar
Paul committed
31

Paul's avatar
Paul committed
32
33
    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
34
        auto add_ins = r.instructions["add"];
Paul's avatar
Paul committed
35
36
        auto ins     = r.result;
        auto args    = add_ins->inputs();
Paul's avatar
Paul committed
37
        // Use the allocation from the relu operator
Paul's avatar
Paul committed
38
        args.back() = ins->inputs().back();
Paul's avatar
Paul committed
39
        p.replace_instruction(ins, hip_add_relu{}, args);
Paul's avatar
Paul committed
40
    }
Paul's avatar
Paul committed
41
42
};

Paul's avatar
Paul committed
43
// Entry point of the fuse_ops pass: hands the program and the add+relu
// rewrite rule to match::find_matches.
void fuse_ops::apply(program& p) const
{
    match::find_matches(p, match_add_relu{});
}
Paul's avatar
Paul committed
44
45
46
47

} // namespace gpu

} // namespace migraph