"configs/eval_qwen1.5_vllm.py" did not exist on "81d0e4d793103e7d607669002985f1278c261cfb"
rewrite_rnn.hpp 3.42 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP

#include <cstddef>
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/common.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
14
struct module;
Shucai Xiao's avatar
Shucai Xiao committed
15
16
17
18
19
20
21

/**
 * Rewrite rnn to gemm and add.
 */
struct rewrite_rnn
{
    std::string name() const { return "rewrite_rnn"; }
22
    void apply(module& m) const;
Shucai Xiao's avatar
Shucai Xiao committed
23
24

    private:
Shucai Xiao's avatar
Shucai Xiao committed
25
    // for vanilla rnn operators
26
    void apply_vanilla_rnn(module& m, instruction_ref ins) const;
Shucai Xiao's avatar
Shucai Xiao committed
27
    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
28
                                                  module& m,
Shucai Xiao's avatar
Shucai Xiao committed
29
                                                  instruction_ref ins,
30
                                                  std::vector<instruction_ref> inputs,
Shucai Xiao's avatar
Shucai Xiao committed
31
                                                  operation& actv_func) const;
Shucai Xiao's avatar
Shucai Xiao committed
32
    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
33

34
    // for gru operators
35
    void apply_gru(module& m, instruction_ref ins) const;
36
    std::vector<instruction_ref> gru_cell(bool is_forward,
37
                                          module& m,
38
39
40
41
42
43
44
                                          instruction_ref ins,
                                          std::vector<instruction_ref> inputs,
                                          int linear_before_reset,
                                          const operation& actv_func1,
                                          const operation& actv_func2) const;

    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
Shucai Xiao's avatar
Shucai Xiao committed
45
46

    // for lstm operators
47
    void apply_lstm(module& m, instruction_ref ins) const;
Shucai Xiao's avatar
Shucai Xiao committed
48
    std::vector<instruction_ref> lstm_cell(bool is_forward,
49
                                           module& m,
Shucai Xiao's avatar
Shucai Xiao committed
50
51
52
53
54
                                           instruction_ref ins,
                                           std::vector<instruction_ref> inputs,
                                           const operation& actv_func1,
                                           const operation& actv_func2,
                                           const operation& actv_func3) const;
Shucai Xiao's avatar
Shucai Xiao committed
55
56

    std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
Shucai Xiao's avatar
Shucai Xiao committed
57

58
59
    bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
    instruction_ref replace_last_hs_output(module& m,
Shucai Xiao's avatar
Shucai Xiao committed
60
61
62
63
64
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref last_hs_output,
                                           op::rnn_direction dirct) const;

65
    void replace_last_cell_output(module& m,
Shucai Xiao's avatar
Shucai Xiao committed
66
67
68
69
70
71
                                  instruction_ref ins,
                                  instruction_ref seq_lens,
                                  instruction_ref cell_outputs,
                                  instruction_ref last_cell_output,
                                  op::rnn_direction dirct) const;

72
    std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
73

74
    instruction_ref pad_hidden_states(module& m,
75
76
77
                                      instruction_ref seq,
                                      instruction_ref seq_lens,
                                      instruction_ref hs) const;
Shucai Xiao's avatar
Shucai Xiao committed
78
79
80
81
82
83
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif