"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "ef5e7ce04d71ccf63a6267b0df107ebec20cb549"
// fuse_ops.cpp
#include <migraph/gpu/fuse_ops.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/device/add_relu.hpp>
#include <migraph/instruction.hpp>

namespace migraph {

namespace gpu {

struct hip_add_relu
{
    std::string name() const { return "hip::add_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(3).standard();
        return inputs.front();
    }
Paul's avatar
Paul committed
18
    argument compute(context&, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
    {
        device::add_relu(args.at(0), args.at(1), args.at(2));
        return args.at(2);
    }
};

/// Walks every instruction of the program and, wherever a "gpu::relu" takes
/// a "gpu::add" as its first argument, replaces the relu with a single
/// hip_add_relu instruction that receives the add's arguments.
void fuse_ops::apply(program& p) const
{
    for(auto relu_ins : iterator_for(p))
    {
        const bool is_relu = relu_ins->op.name() == "gpu::relu";
        if(is_relu)
        {
            // Only the first argument of the relu is inspected as the producer.
            auto producer = relu_ins->arguments.front();
            const bool fed_by_add = producer->op.name() == "gpu::add";
            if(fed_by_add)
            {
                // NOTE(review): the fused op is given the add's full argument
                // list, so it presumably writes into the same buffer the add
                // used; confirm this is safe if the add has other consumers.
                p.replace_instruction(relu_ins, hip_add_relu{}, producer->arguments);
            }
        }
    }
}

} // namespace gpu

} // namespace migraph