// rewrite_rnn.hpp
// Committed by Shucai Xiao (per the repository history this copy was taken from).
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP

#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite rnn to gemm and add.
 */
struct rewrite_rnn
{
    std::string name() const { return "rewrite_rnn"; }
    void apply(program& prog) const;

    private:
Shucai Xiao's avatar
Shucai Xiao committed
24
25
26
    // for vanilla rnn operators
    void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
27
28
29
30
31
32
33
34
                                                  program& prog,
                                                  instruction_ref ins,
                                                  instruction_ref input,
                                                  instruction_ref w,
                                                  instruction_ref r,
                                                  instruction_ref bias,
                                                  instruction_ref ih,
                                                  operation& actv_func) const;
Shucai Xiao's avatar
Shucai Xiao committed
35
    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
36

37
38
39
40
41
42
43
44
45
46
47
    // for gru operators
    void apply_gru(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> gru_cell(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          std::vector<instruction_ref> inputs,
                                          int linear_before_reset,
                                          const operation& actv_func1,
                                          const operation& actv_func2) const;

    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
Shucai Xiao's avatar
Shucai Xiao committed
48
49
50
51
52
53
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif