#include "hip/hip_runtime.h"

#include "misc_kernels_impl.cuh"
#include "misc_kernels.h"
#include "dispatch_utils.h"

namespace nunchaku::kernels {

Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
10
11
12
13
14
Tensor add(Tensor a, Tensor b) {
    assert(a.shape.dataExtent == b.shape.dataExtent);
    assert(a.dtype() == b.dtype());
    assert(a.is_contiguous());
    assert(b.is_contiguous());

    int threadsPerBlock = 1024;
Muyang Li's avatar
Muyang Li committed
15
    int blocksPerGrid   = (a.numel() + threadsPerBlock - 1) / threadsPerBlock;
Zhekai Zhang's avatar
Zhekai Zhang committed
16

fengzch-das's avatar
fengzch-das committed
17
    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();
Zhekai Zhang's avatar
Zhekai Zhang committed
18
19
20
21

    Tensor out = Tensor::empty_like(a);

    dispatch(out.scalar_type(), [&]<typename scalar_t>() {
fengzch-das's avatar
fengzch-das committed
22
       hipLaunchKernelGGL(( add_kernel), dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream, 
Zhekai Zhang's avatar
Zhekai Zhang committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
            a.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), out.numel());
    });

    return out;
}

// In-place fused multiply-add on `x`, with `scale` and `bias` broadcast
// cyclically over x's elements. If `scale` is invalid, the bias-only kernel
// variant is selected (third template flag = true).
// NOTE(review): exact kernel semantics live in misc_kernels_impl.cuh.
void mul_add(Tensor x, Tensor scale, Tensor bias) {
    // assert(scale.shape.data == bias.shape.data);
    // FIXME FIXME
    // Guard all scale checks on scale.valid(), matching mul_add_batch():
    // an invalid scale must not be queried (numel could be 0 -> UB in the %).
    assert(!scale.valid() || x.numel() % scale.numel() == 0);
    assert(x.numel() % bias.numel() == 0);
    assert(!scale.valid() || x.dtype() == scale.dtype());
    assert(x.dtype() == bias.dtype());

    constexpr int unroll = 8;

    // Each thread processes `unroll` elements; pointers must be aligned to
    // the vectorized access width.
    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);

    assert(x.numel() % unroll == 0);
    assert(!scale.valid() || scale.numel() % unroll == 0);
    assert(bias.numel() % unroll == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        if (scale.valid()) {
            hipLaunchKernelGGL(( mul_add_kernel<scalar_t, unroll, false>),
                dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream,
                x.data_ptr<scalar_t>(),
                scale.data_ptr<scalar_t>(),
                bias.data_ptr<scalar_t>(),
                0,              // scale_shift slot: unused in the non-batch path
                x.numel(),
                scale.numel(),
                bias.numel(),
                0, 0, 0);       // batch strides: single batch
        } else {
            hipLaunchKernelGGL(( mul_add_kernel<scalar_t, unroll, true>),
                dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream,
                x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(),
                0, x.numel(), 1, bias.numel(), 0, 0, 0);
        }
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });
}

// Batched in-place multiply-add on x (leading dim = batch). `scale` and
// `bias` are either shared across the batch or carry their own leading
// batch dim (batch_scale / batch_bias). `scale_shift` is forwarded to the
// kernel alongside the scale (exact semantics in misc_kernels_impl.cuh).
// An invalid `scale` selects the bias-only kernel variant.
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias) {

    const int batch_size = x.shape[0];
    assert(!batch_scale || scale.shape[0] == batch_size);
    assert(!batch_bias || bias.shape[0] == batch_size);

    // Per-batch-element counts; non-batched operands broadcast cyclically.
    const int numel       = x.numel() / batch_size;
    const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1;
    const int numel_bias  = bias.numel() / (batch_bias ? batch_size : 1);

    assert(numel % numel_scale == 0);
    assert(numel % numel_bias == 0);
    assert(!scale.valid() || x.dtype() == scale.dtype());
    assert(x.dtype() == bias.dtype());

    constexpr int unroll = 8;

    // Vectorized access: pointers must be aligned to `unroll` elements.
    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);

    assert(numel % unroll == 0);
    assert(!scale.valid() || numel_scale % unroll == 0);
    assert(numel_bias % unroll == 0);

    int threadsPerBlock = 1024;
    // 2-D grid: x-dim covers one batch element's data, y-dim is batch_size.
    dim3 grid(ceilDiv(numel, threadsPerBlock * unroll), batch_size);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        if (scale.valid()) {
            hipLaunchKernelGGL(( mul_add_kernel<scalar_t, unroll, false>),
                dim3(grid), dim3(threadsPerBlock), 0, stream,
                x.data_ptr<scalar_t>(),
                scale.data_ptr<scalar_t>(),
                bias.data_ptr<scalar_t>(),
                (scalar_t)scale_shift,
                numel,
                numel_scale,
                numel_bias,
                x.stride(0),
                batch_scale ? scale.stride(0) : 0,
                batch_bias ? bias.stride(0) : 0);
        } else {
            hipLaunchKernelGGL(( mul_add_kernel<scalar_t, unroll, true>),
                dim3(grid), dim3(threadsPerBlock), 0, stream,
                x.data_ptr<scalar_t>(),
                nullptr,
                bias.data_ptr<scalar_t>(),
                (scalar_t)scale_shift,
                numel,
                1,
                numel_bias,
                x.stride(0),
                0,
                batch_bias ? bias.stride(0) : 0);
        }
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });
}

// Gathers rows of the 2-D `lookup` table by the int32 ids in `input_id`.
// Output shape is input_id.shape with lookup.shape[-1] appended; dtype and
// values come from `lookup`.
Tensor embedding(Tensor input_id, Tensor lookup) {
    assert(input_id.dtype() == Tensor::INT32);
    assert(lookup.ndims() == 2);

    auto shapeOut = input_id.shape;
    shapeOut.dataExtent.push_back(lookup.shape[-1]);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    Tensor out = Tensor::empty(shapeOut, lookup.scalar_type(), input_id.device());

    dispatch(out.scalar_type(), [&]<typename scalar_t>() {
        // One block per id; block width capped at 1024 threads.
        hipLaunchKernelGGL(( EmbeddingKernel),
            dim3(input_id.numel()), dim3(std::min(lookup.shape[-1], 1024)), 0, stream,
            input_id.data_ptr<int32_t>(), out.data_ptr<scalar_t>(),
            lookup.data_ptr<scalar_t>(), lookup.shape[-1]);
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });

    return out;
}

// Row-wise argmax over a 2-D logits tensor; returns INT32 indices of shape
// [logits.shape[0]].
Tensor argmax_sample(Tensor logits) {
    assert(logits.ndims() == 2);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    Tensor out = Tensor::empty({logits.shape[0]}, Tensor::INT32, logits.device());

    dispatch(logits.scalar_type(), [&]<typename scalar_t>() {
        // One block per row; block width capped at 1024 threads.
        hipLaunchKernelGGL(( argmax_sample_kernel),
            dim3(logits.shape[0]), dim3(std::min(logits.shape[1], 1024)), 0, stream,
            logits.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), logits.shape[1]);
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });

    return out;
}

// Splits a fused QKV tensor into the q, k, v destination tensors along the
// last dimension. Per-token widths are taken from the trailing two dims of
// each destination; K and V must be equally wide and the three widths must
// sum to qkv.shape[-1].
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v) {
    // FIXME FIXME
    // assert(qkv.shape[0] == q.shape[0]);
    // assert(qkv.shape[0] == k.shape[0]);
    // assert(qkv.shape[0] == v.shape[0]);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    int dim_q = q.shape[-1] * q.shape[-2];
    int dim_k = k.shape[-1] * k.shape[-2];
    int dim_v = v.shape[-1] * v.shape[-2];

    assert(dim_k == dim_v);
    assert(dim_q + dim_k + dim_v == qkv.shape[-1]);

    int num_tokens = qkv.numel() / qkv.shape[-1];

    dispatch(qkv.scalar_type(), [&]<typename scalar_t>() {
        // One block per token; block width capped at 1024 threads.
        hipLaunchKernelGGL(( splitqkv_kernel),
            dim3(num_tokens), dim3(std::min(qkv.shape[-1], 1024)), 0, stream,
            qkv.data_ptr<scalar_t>(),
            q.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(),
            v.data_ptr<scalar_t>(),
            dim_q,
            dim_k);
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });
}

// Splits `input` into N equal chunks along the last dimension, returning N
// freshly allocated tensors. input.shape[-1] must be divisible by N.
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input) {
    assert(input.shape[-1] % N == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (input.numel() + threadsPerBlock - 1) / threadsPerBlock;

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    auto shapeOut = TensorShape(input.shape.dataExtent);
    shapeOut[-1] /= N;

    std::array<Tensor, N> out;
    for (int k = 0; k < N; k++) {
        out[k] = Tensor::empty(shapeOut, input.scalar_type(), input.device());
    }

    dispatch(input.scalar_type(), [&]<typename scalar_t>() {
        // Collect raw output pointers so the kernel can scatter directly.
        std::array<scalar_t *, N> outPtr;
        for (int k = 0; k < N; k++) {
            outPtr[k] = out[k].template data_ptr<scalar_t>();
        }
        hipLaunchKernelGGL(( split_mod_kernel),
            dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream,
            input.data_ptr<scalar_t>(), outPtr, input.numel());
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });

    return out;
}

// Quantizes `x` to an INT8 tensor of the same shape using the single static
// scale factor `scale` (cast to x's scalar type for the kernel).
Tensor quant_static(Tensor x, float scale) {
    Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());

    constexpr int unroll = 8;

    // Vectorized loads: input pointer must be aligned to `unroll` elements.
    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        hipLaunchKernelGGL(( quant_kernel_static<scalar_t, unroll>),
            dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream,
            x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });

    return out;
}

// Same as quant_static() but dispatches the GELU-fused kernel variant
// (per the kernel name; the activation is applied inside the kernel).
Tensor quant_static_fuse_gelu(Tensor x, float scale) {
    Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());

    constexpr int unroll = 8;

    // Vectorized loads: input pointer must be aligned to `unroll` elements.
    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        hipLaunchKernelGGL(( quant_kernel_static_fuse_gelu<scalar_t, unroll>),
            dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream,
            x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
        checkCUDA(hipGetLastError()); // surface launch errors, matching cast()/topk()
    });

    return out;
}

261
262
263
264
265
void cast(Tensor input, Tensor output) {
    assert(input.is_contiguous());
    assert(output.is_contiguous());
    assert(input.shape.dataExtent == output.shape.dataExtent);

266
267
268
269
    if (input.data_ptr() == output.data_ptr()) {
        assert(input.scalar_size() == output.scalar_size());
    }

fengzch-das's avatar
fengzch-das committed
270
    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();
271
272
273
274
275
276

    dispatch(input.scalar_type(), [&]<typename input_t>() {
        dispatch(output.scalar_type(), [&]<typename output_t>() {
            constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t));

            int threadsPerBlock = 1024;
Muyang Li's avatar
Muyang Li committed
277
            int blocksPerGrid   = (int)ceilDiv<int64_t>(input.numel(), threadsPerBlock * unroll);
278

fengzch-das's avatar
fengzch-das committed
279
           hipLaunchKernelGGL(( cast_kernel<input_t, output_t, unroll>), dim3(blocksPerGrid), dim3(threadsPerBlock), 0, stream, 
280
281
                input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel());

fengzch-das's avatar
fengzch-das committed
282
            checkCUDA(hipGetLastError());
283
284
285
286
        });
    });
}

Zhekai Zhang's avatar
Zhekai Zhang committed
287
288
289
Tensor topk(Tensor x, int k) {
    constexpr int MAXK = 64 + 4;

Muyang Li's avatar
Muyang Li committed
290
    const int N     = x.shape[-1];
Zhekai Zhang's avatar
Zhekai Zhang committed
291
292
293
294
295
    const int batch = x.numel() / N;

    assert(k <= N);
    assert(k <= MAXK);

muyangli's avatar
muyangli committed
296
    auto outShape = TensorShape(x.shape.dataExtent);
Muyang Li's avatar
Muyang Li committed
297
    outShape[-1]  = k;
Zhekai Zhang's avatar
Zhekai Zhang committed
298
299
300
301
    outShape.dataStride.clear();

    Tensor out = Tensor::empty(outShape, Tensor::INT32, x.device());

fengzch-das's avatar
fengzch-das committed
302
    auto stream = getCurrentHIPStreamMasqueradingAsCUDA();
Zhekai Zhang's avatar
Zhekai Zhang committed
303
304
305
306
307
308
309
310

    dispatchVal(k, std::make_integer_sequence<int, MAXK + 1>(), [&]<int K>() {
        if constexpr (K == 0) {
            assert(false);
            return;
        }
        if constexpr (K > 0) {
            dispatch(x.scalar_type(), [&]<typename scalar_t>() {
fengzch-das's avatar
fengzch-das committed
311
               hipLaunchKernelGGL(( topk_kernel<scalar_t, K>), dim3(ceilDiv(batch, 32)), dim3(32), 0, stream, 
Muyang Li's avatar
Muyang Li committed
312
                    x.data_ptr<scalar_t>(), out.data_ptr<int>(), N, x.stride(-2), batch);
fengzch-das's avatar
fengzch-das committed
313
                checkCUDA(hipGetLastError());
Zhekai Zhang's avatar
Zhekai Zhang committed
314
315
316
317
318
319
320
321
322
323
324
            });
        }
    });

    return out;
}

// Explicit instantiations of split_mod for the chunk counts used by callers.
template std::array<Tensor, 2> split_mod<2>(Tensor input);
template std::array<Tensor, 3> split_mod<3>(Tensor input);
template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input);

}; // namespace nunchaku::kernels