#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

#include <cmath>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_attention : op_parser<parse_attention>
{
    std::vector<op_desc> operators() const { return {{"Attention"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
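        // Per ONNX Runtime's com.microsoft.Attention contrib op, the inputs are expected to be:
        //   input        : (batch_size, sequence_length, input_hidden_size)
        //   weights      : (input_hidden_size, 3 * hidden_size)
        //   bias         : (3 * hidden_size)
        //   mask_index   : attention mask (1-D index form or 2-D/3-D/4-D raw form)
        //   past         : optional packed key/value cache
        //                  (2, batch_size, num_heads, past_sequence_length, head_size)
        //   extra_add_qk : optional additive bias for the Q*K' scores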
        auto input      = args[0];
        auto weights    = args[1];
        auto bias       = args[2];
        auto mask_index = args[3];

        instruction_ref past;
        instruction_ref extra_add_qk;
        bool is_past         = false;
        bool is_extra_add_qk = false;
        if(args.size() > 4)
        {
            past    = args[4];
            is_past = true;
        }
        if(args.size() == 6)
        {
            is_extra_add_qk = true;
            extra_add_qk    = args[5];
        }

        // ORT default is 12
        std::size_t num_heads = 12;
        if(contains(info.attributes, "num_heads"))
            num_heads = info.attributes.at("num_heads").i();

        // input shape: (batch_size, sequence_length, input_hidden_size)
        auto input_lens        = input->get_shape().lens();
        auto batch_size        = input_lens.at(0);
        auto sequence_length   = input_lens.at(1);
        auto input_hidden_size = input_lens.at(2);

        // bias shape: (3 * hidden_size)
        auto bias_lens           = bias->get_shape().lens();
        auto hidden_size         = bias_lens.at(0) / 3;
        auto head_size           = hidden_size / num_heads;
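        // head_size assumes hidden_size is evenly divisible by num_heads; no check is made here.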
        int past_sequence_length = 0;

        // GetPresent
        //    Input and output shapes:
        //      past        : (2, batch_size, num_heads, past_sequence_length, head_size)
        //      present     : (2, batch_size, num_heads, past_sequence_length + sequence_length,
        //      head_size)
        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
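        // NOTE: only the attention output is built below; the present key/value output described
        // by present_lens is not emitted yet.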

        if(is_past)
        {
            auto past_lens       = past->get_shape().lens();
            past_sequence_length = past_lens.at(3);
            present_lens[3] += past_lens[3];
        }

        // Use a GEMM for the fully connected QKV projection.
        auto m = batch_size * sequence_length;
        auto n = bias_lens.front();
        auto k = input_hidden_size;

        // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
        auto bias_type = bias->get_shape().type();
        std::vector<float> ones_vec(m, 1);
        std::vector<std::size_t> ones_lens{1, m};
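        // For example, with batch_size = 1, sequence_length = 8 and hidden_size = 768:
        // bias(2304, 1) x ones(1, 8) -> (2304, 8), transposed below to (8, 2304).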
        auto ones =
            info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
        bias        = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
        auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
        gemm_1 =
            info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);

        // ORT uses a GEMM here; its ROCm kernel assumes column-major, so it computes
        // result(N, M) = 1 * weights x input + 1 * B. With row-major layout we compute
        // result(M, N) = 1 * input x weights + 1 * B instead.
        // Flatten input to (M, K); K is input_hidden_size, which need not equal hidden_size.
        auto input_sq =
            info.add_instruction(migraphx::make_op("reshape", {{"dims", {m, k}}}), input);
        auto gemm_2    = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
        auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
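        // add_gemms: (M, 3 * hidden_size) holding the packed Q, K and V projections plus bias.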

        // LaunchAttentionKernel:
        //   LaunchTransQkv
        // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
        add_gemms = info.add_instruction(
            migraphx::make_op("reshape",
                              {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
            add_gemms);
        std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
        auto transqkv = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);

        // now scratch3 has Q, K, V: each has size BxNxSxH
        // => transqkv has shape 3xBxNxSxH
        auto batches        = batch_size * num_heads;
        auto size_per_batch = sequence_length * head_size;
        auto total_size     = batches * size_per_batch;
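        // These sizes mirror the scratch bookkeeping in ORT's attention kernel; they are not
        // needed to build the graph here.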

        auto q_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
        auto k_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
        auto v_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
        q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
        k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
        v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);

        if(is_past)
        {
            // past packs the cached key and value as (2, batch_size, num_heads,
            // past_sequence_length, head_size); split it and prepend each half to the
            // current key/value along the sequence axis (axis 2 of BxNxSxH).
            auto past_k = info.add_instruction(
                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), past);
            auto past_v = info.add_instruction(
                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), past);
            past_k = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), past_k);
            past_v = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), past_v);
            k_t    = info.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), past_k, k_t);
            v_t    = info.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), past_v, v_t);
        }

        // Raw attention mask could be 2D (BxS), 3D (BxSxS*), or 4D (Bx1xMxM), where M is the max
        // sequence length.
        auto mask_index_lens        = mask_index->get_shape().lens();
        bool use_raw_attention_mask = mask_index_lens.size() >= 2;
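        // A 1-D mask_index encodes sequence lengths; 2-D or higher is treated as a raw mask.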

        // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
        // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
        const float rsqrt_head_size   = 1.f / std::sqrt(static_cast<float>(head_size));
        const int all_sequence_length = past_sequence_length + sequence_length;
        const int temp_matrix_size    = sequence_length * all_sequence_length;

        // For a raw attention mask, the 1/sqrt(H) scaling is deferred to the softmax computation.
        const float alpha = use_raw_attention_mask ? 1.0f : rsqrt_head_size;

        // K{B,N,S,H} -> K'{B,N,H,S}
        k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
        auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
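        // gemm3 holds the raw Q*K' scores (BxNxSxS*); extra_add_qk, if provided, is added
        // before scaling.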
        if(is_extra_add_qk)
            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
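        // Broadcast alpha over the score tensor and apply the scaling.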
        auto alpha_lit = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
            info.add_literal(
                migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
        gemm3 =
            info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));

        // apply softmax and store result P to scratch2: BxNxSxS*
        // TODO: apply the attention mask to the scores before the softmax:
        //   - mask_index_lens.size() >= 2: raw additive mask (2D/3D/4D)
        //   - mask_index_lens.size() == 1: per-batch sequence-length indices
        // Neither path is implemented yet, so no mask is applied.
        auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);

        // compute P*V and store in scratch3: BxNxSxH
        auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);

        // scratch3 is BxNxSxH, transpose to output BxSxNxH
        gemm4 = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
        gemm4 = info.add_instruction(
            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
            info.make_contiguous(gemm4));
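        // Final output: (batch_size, sequence_length, hidden_size).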
        return gemm4;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx