encode_rotary_qk.cu 8.18 KB
Newer Older
yuguo-Jack's avatar
yuguo-Jack committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
//     http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Applies a rotary position embedding in the "rotate-half" (NeoX) layout:
// element `ti` of the first half of the last dimension is paired with element
// `ti + last_dim / 2` of the second half.
//
// Launch layout: one block per (batch, head, position) triple, taken from
// blockIdx.x / blockIdx.y / blockIdx.z respectively; the threads of a block
// stride over the `last_dim / 2` pairs. No shared memory is used.
//
// Indexing (as implied by the offsets computed below):
//   input / output : [batch, head_num, seq_len, last_dim], row-major
//   cos_emb/sin_emb: [batch, seq_len, last_dim], row-major
//
// Args:
//   input:            source values.
//   cos_emb, sin_emb: cosine / sine tables.
//   sequence_lengths: optional per-batch lengths; when non-null, blocks with
//                     position >= sequence_lengths[batch] * rotary_emb_dims
//                     exit without writing anything.
//   output:           destination. It may alias `input`: every pair is owned
//                     by exactly one thread, which reads both elements before
//                     writing either. For that reason `input` and `output`
//                     must NOT be marked __restrict__.
//   rotary_emb_dims:  multiplier applied to the per-batch sequence length.
//   batch_size:       unused by the kernel; kept for interface stability.
//   head_num, seq_len, last_dim: extents used for the offset computation.
//
// All flat offsets are computed in 64-bit so tensors with more than 2^31
// elements do not overflow.
template <typename T>
__global__ void NeoXRotaryKernel(const T *input,
                                 const float *cos_emb,
                                 const float *sin_emb,
                                 const int *sequence_lengths,
                                 T *output,
                                 const int rotary_emb_dims,
                                 const int batch_size,
                                 const int head_num,
                                 const int seq_len,
                                 const int last_dim) {
  const int bi = blockIdx.x;
  const int hi = blockIdx.y;
  const int si = blockIdx.z;
  // Positions beyond the valid length of this batch entry are left untouched.
  if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;

  const int half_lastdim = last_dim / 2;
  // Loop-invariant row offsets, widened to 64-bit before multiplying.
  const int64_t base_idx =
      ((static_cast<int64_t>(bi) * head_num + hi) * seq_len + si) * last_dim;
  const int64_t emb_base =
      (static_cast<int64_t>(bi) * seq_len + si) * last_dim;

  for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) {
    const int64_t left_idx = base_idx + ti;
    const int64_t right_idx = left_idx + half_lastdim;
    const int64_t emb_idx_left = emb_base + ti;
    const int64_t emb_idx_right = emb_idx_left + half_lastdim;

    // Read both members of the pair before any write (in-place safety).
    const float input_left = static_cast<float>(input[left_idx]);
    const float input_right = static_cast<float>(input[right_idx]);

    const float cos_tmp_left = cos_emb[emb_idx_left];
    const float sin_tmp_left = sin_emb[emb_idx_left];
    const float cos_tmp_right = cos_emb[emb_idx_right];
    const float sin_tmp_right = sin_emb[emb_idx_right];

    // The math is done in float and converted back to T on store.
    output[left_idx] =
        static_cast<T>(input_left * cos_tmp_left - input_right * sin_tmp_left);
    output[right_idx] = static_cast<T>(input_right * cos_tmp_right +
                                       input_left * sin_tmp_right);
  }
}


// Applies a rotary position embedding in the interleaved layout: elements
// `2 * ti` and `2 * ti + 1` of the last dimension form one rotated pair, and
// both use the cos/sin value stored at index `2 * ti` of the embedding row.
//
// Launch layout: one block per (batch, head, position) triple, taken from
// blockIdx.x / blockIdx.y / blockIdx.z respectively; the threads of a block
// stride over the `last_dim / 2` pairs. No shared memory is used.
//
// Indexing (as implied by the offsets computed below):
//   input / output : [batch, head_num, seq_len, last_dim], row-major
//   cos_emb/sin_emb: [batch, seq_len, last_dim], row-major
//
// Args:
//   input:            source values.
//   cos_emb, sin_emb: cosine / sine tables.
//   sequence_lengths: optional per-batch lengths; when non-null, blocks with
//                     position >= sequence_lengths[batch] * rotary_emb_dims
//                     exit without writing anything.
//   output:           destination. It may alias `input`: every pair is owned
//                     by exactly one thread, which reads both elements before
//                     writing either. For that reason `input` and `output`
//                     must NOT be marked __restrict__.
//   rotary_emb_dims:  multiplier applied to the per-batch sequence length.
//   batch_size:       unused by the kernel; kept for interface stability.
//   head_num, seq_len, last_dim: extents used for the offset computation.
//
// All flat offsets are computed in 64-bit so tensors with more than 2^31
// elements do not overflow.
template <typename T>
__global__ void RotaryKernel(const T *input,
                             const float *cos_emb,
                             const float *sin_emb,
                             const int *sequence_lengths,
                             T *output,
                             const int rotary_emb_dims,
                             const int batch_size,
                             const int head_num,
                             const int seq_len,
                             const int last_dim) {
  const int bi = blockIdx.x;
  const int hi = blockIdx.y;
  const int si = blockIdx.z;
  // Positions beyond the valid length of this batch entry are left untouched.
  if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;

  const int half_lastdim = last_dim / 2;
  // Loop-invariant row offsets, widened to 64-bit before multiplying.
  const int64_t base_idx =
      ((static_cast<int64_t>(bi) * head_num + hi) * seq_len + si) * last_dim;
  const int64_t emb_base =
      (static_cast<int64_t>(bi) * seq_len + si) * last_dim;

  // Note(ZhenyuLi): Calculate the relevant data at one time, so that no
  // additional space is required.
  for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) {
    const int64_t left_idx = base_idx + 2 * static_cast<int64_t>(ti);
    const int64_t right_idx = left_idx + 1;
    const int64_t emb_idx = emb_base + 2 * static_cast<int64_t>(ti);

    // Read both members of the pair before any write (in-place safety).
    const float input_left = static_cast<float>(input[left_idx]);
    const float input_right = static_cast<float>(input[right_idx]);
    const float cos_tmp = cos_emb[emb_idx];
    const float sin_tmp = sin_emb[emb_idx];

    // The math is done in float and converted back to T on store.
    output[left_idx] =
        static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
    output[right_idx] =
        static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
  }
}

// Rotates `q` and `kv` in place with the rotary embedding stored in
// `rotary_emb`, on the stream returned by `q.stream()`.
//
// Template parameter:
//   D: Paddle dtype of `q` / `kv`; mapped to the device element type through
//      PDTraits<D>.
//
// Args:
//   q:               4-D tensor read as [batch, head_num, seq_len, dim_head].
//                    All launch extents are taken from this tensor.
//   kv:              rotated with the same extents as `q`.
//                    NOTE(review): nothing here verifies that `kv` has the
//                    same shape as `q` -- confirm against callers.
//   rotary_emb:      float buffer; the cosine table starts at offset 0 and the
//                    sine table at offset batch * seq_len * dim_head.
//   seq_lens:        int buffer of per-batch lengths forwarded to the kernels.
//   rotary_emb_dims: the head dimension is split into this many chunks of
//                    dim_head / rotary_emb_dims elements; must be positive.
//   use_neox:        selects NeoXRotaryKernel when true, RotaryKernel
//                    otherwise.
//
// Throws (PD_THROW) when `rotary_emb_dims` is not positive or `q` is not 4-D.
// Returns without launching anything when `q` has a zero-sized dimension.
//
// NOTE(review): head_num and seq_len * rotary_emb_dims are used directly as
// gridDim.y / gridDim.z, which the device limits -- very long sequences would
// need a different grid mapping; confirm the supported range.
template <paddle::DataType D>
void LaunchRotaryQK(const paddle::Tensor& q,
                    const paddle::Tensor& kv,
                    const paddle::Tensor& rotary_emb,
                    const paddle::Tensor& seq_lens,
                    const int32_t rotary_emb_dims,
                    bool use_neox) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    // rotary_emb_dims is used as a divisor and as a grid multiplier below.
    if (rotary_emb_dims <= 0) {
        PD_THROW("rotary_emb_dims must be a positive integer. ");
    }

    const auto q_shape = q.shape();
    if (q_shape.size() != 4) {
        PD_THROW("q must be a 4-D tensor [batch, head_num, seq_len, dim_head]. ");
    }
    const int32_t batch_size = static_cast<int32_t>(q_shape[0]);
    const int32_t head_num = static_cast<int32_t>(q_shape[1]);
    const int32_t seq_len = static_cast<int32_t>(q_shape[2]);
    const int32_t dim_head = static_cast<int32_t>(q_shape[3]);

    // A zero-sized grid is an invalid launch configuration; there is nothing
    // to rotate in that case anyway.
    if (batch_size == 0 || head_num == 0 || seq_len == 0 || dim_head == 0) {
        return;
    }

    auto cu_stream = q.stream();
    // Each chunk of the head dimension is treated as its own position, so the
    // kernels see seq_len * rotary_emb_dims rows of last_dim elements.
    const int rotary_seq_len = seq_len * rotary_emb_dims;
    const int last_dim = dim_head / rotary_emb_dims;
    dim3 grid(batch_size, head_num, rotary_seq_len);

    // Smallest power of two >= last_dim / 2, clamped to [32, 512]; every
    // thread handles one pair of elements per loop iteration.
    const int pairs_per_row = last_dim / 2;
    int block_size = 32;
    while (block_size < pairs_per_row && block_size < 512) {
        block_size *= 2;
    }

    // Cosine table first, sine table right after it. The offset is computed in
    // 64-bit to avoid overflowing the 32-bit product.
    const float* cos_emb = rotary_emb.data<float>();
    const float* sin_emb =
        cos_emb + static_cast<int64_t>(batch_size) * seq_len * dim_head;
    const int* seq_lens_data = seq_lens.data<int>();

    // The op is in-place: each tensor is both kernel input and kernel output.
    DataType_* q_buf =
        reinterpret_cast<DataType_*>(const_cast<data_t*>(q.data<data_t>()));
    DataType_* kv_buf =
        reinterpret_cast<DataType_*>(const_cast<data_t*>(kv.data<data_t>()));

    // Both kernels share one signature, so pick the variant once.
    typedef void (*RotaryKernelFn)(const DataType_*,
                                   const float*,
                                   const float*,
                                   const int*,
                                   DataType_*,
                                   int,
                                   int,
                                   int,
                                   int,
                                   int);
    RotaryKernelFn kernel = RotaryKernel<DataType_>;
    if (use_neox) {
        kernel = NeoXRotaryKernel<DataType_>;
    }

    // q first, then kv, on the same stream.
    DataType_* const buffers[2] = {q_buf, kv_buf};
    for (DataType_* buf : buffers) {
        kernel<<<grid, block_size, 0, cu_stream>>>(
            buf,
            cos_emb,
            sin_emb,
            seq_lens_data /*sequence_lengths*/,
            buf,
            rotary_emb_dims,
            batch_size,
            head_num,
            rotary_seq_len,
            last_dim);
    }
}

// Operator entry point: picks the LaunchRotaryQK instantiation that matches
// the dtype of `q` and forwards every argument unchanged. Throws through
// PD_THROW for any dtype other than bfloat16, float16 or float32.
void RotaryQK(const paddle::Tensor& q, 
              const paddle::Tensor& kv, 
              const paddle::Tensor& rotary_emb, 
              const paddle::Tensor& seq_lens,
              const int32_t rotary_emb_dims, 
              bool use_neox) {
    const auto q_dtype = q.type();

    if (q_dtype == paddle::DataType::BFLOAT16) {
        LaunchRotaryQK<paddle::DataType::BFLOAT16>(
            q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
        return;
    }
    if (q_dtype == paddle::DataType::FLOAT16) {
        LaunchRotaryQK<paddle::DataType::FLOAT16>(
            q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
        return;
    }
    if (q_dtype == paddle::DataType::FLOAT32) {
        LaunchRotaryQK<paddle::DataType::FLOAT32>(
            q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
        return;
    }

    // Anything else is rejected.
    PD_THROW(
        "NOT supported data type. "
        "Only bfloat16, float16 and float32 are supported. ");
}



// Registers the `encode_rotary_qk` custom operator with RotaryQK as its
// kernel function.
//   Inputs : q, kv, rotary_emb, seq_lens
//   Outputs: rotary_q_out, rotary_kv_out -- declared through the in-place map
//            as aliases of q and kv, matching the launcher, which writes the
//            results back into the q / kv buffers.
//   Attrs  : rotary_emb_dims (int), use_neox (bool)
PD_BUILD_OP(encode_rotary_qk)
    .Inputs({"q", "kv", "rotary_emb", "seq_lens"})
    .Outputs({"rotary_q_out", "rotary_kv_out"})
    .SetInplaceMap({{"q", "rotary_q_out"}, {"kv", "rotary_kv_out"}})
    .Attrs({"rotary_emb_dims: int", "use_neox: bool"})
    .SetKernelFn(PD_KERNEL(RotaryQK));