#include "misc_kernels_impl.cuh"
#include "misc_kernels.h"
#include "dispatch_utils.h"

namespace nunchaku::kernels {

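// Elementwise out = a + b over two contiguous tensors of identical shape and
// dtype, with at least one thread per element (add_kernel itself is defined
// in misc_kernels_impl.cuh; semantics inferred from the launch signature).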
Tensor add(Tensor a, Tensor b) {
    assert(a.shape.dataExtent == b.shape.dataExtent);
    assert(a.dtype() == b.dtype());
    assert(a.is_contiguous());
    assert(b.is_contiguous());

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (a.numel() + threadsPerBlock - 1) / threadsPerBlock;

    auto stream = getCurrentCUDAStream();

    Tensor out = Tensor::empty_like(a);

    dispatch(out.scalar_type(), [&]<typename scalar_t>() {
        add_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
            a.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), out.numel());
    });

    return out;
}

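// Computes x = x * scale + bias in place, presumably elementwise with scale
// and bias broadcast cyclically over x (the kernel body lives in
// misc_kernels_impl.cuh). When scale is an invalid tensor, only the bias is
// applied; the third template argument of mul_add_kernel appears to select
// that no-scale path. Each thread handles `unroll` elements, hence the
// alignment and divisibility asserts below.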
void mul_add(Tensor x, Tensor scale, Tensor bias) {
    // assert(scale.shape.data == bias.shape.data);
    // FIXME FIXME
    assert(!scale.valid() || x.numel() % scale.numel() == 0);
    assert(x.numel() % bias.numel() == 0);
    assert(!scale.valid() || x.dtype() == scale.dtype());
    assert(x.dtype() == bias.dtype());

    constexpr int unroll = 8;

    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);

    assert(x.numel() % unroll == 0);
    assert(!scale.valid() || scale.numel() % unroll == 0);
    assert(bias.numel() % unroll == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentCUDAStream();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        if (scale.valid()) {
            mul_add_kernel<scalar_t, unroll, false>
                <<<blocksPerGrid, threadsPerBlock, 0, stream>>>(x.data_ptr<scalar_t>(),
                                                                scale.data_ptr<scalar_t>(),
                                                                bias.data_ptr<scalar_t>(),
                                                                0,
                                                                x.numel(),
                                                                scale.numel(),
                                                                bias.numel(),
                                                                0,
                                                                0,
                                                                0);
        } else {
            mul_add_kernel<scalar_t, unroll, true><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
                x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(), 0, x.numel(), 1, bias.numel(), 0, 0, 0);
        }
    });
}

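// Batched variant of mul_add: x is treated as [batch_size, numel]. scale and
// bias are either shared across the batch or indexed per batch element via
// their leading stride (batch_scale / batch_bias); scale_shift is presumably
// added to scale inside the kernel. grid.y enumerates the batch.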
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias) {

    const int batch_size = x.shape[0];
    assert(!batch_scale || scale.shape[0] == batch_size);
    assert(!batch_bias || bias.shape[0] == batch_size);

    const int numel       = x.numel() / batch_size;
    const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1;
    const int numel_bias  = bias.numel() / (batch_bias ? batch_size : 1);

    assert(numel % numel_scale == 0);
    assert(numel % numel_bias == 0);
    assert(!scale.valid() || x.dtype() == scale.dtype());
    assert(x.dtype() == bias.dtype());

    constexpr int unroll = 8;

    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
    assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);

    assert(numel % unroll == 0);
    assert(!scale.valid() || numel_scale % unroll == 0);
    assert(numel_bias % unroll == 0);

    int threadsPerBlock = 1024;
    dim3 grid(ceilDiv(numel, threadsPerBlock * unroll), batch_size);

    auto stream = getCurrentCUDAStream();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        if (scale.valid()) {
            mul_add_kernel<scalar_t, unroll, false>
                <<<grid, threadsPerBlock, 0, stream>>>(x.data_ptr<scalar_t>(),
                                                       scale.data_ptr<scalar_t>(),
                                                       bias.data_ptr<scalar_t>(),
                                                       (scalar_t)scale_shift,
                                                       numel,
                                                       numel_scale,
                                                       numel_bias,
                                                       x.stride(0),
                                                       batch_scale ? scale.stride(0) : 0,
                                                       batch_bias ? bias.stride(0) : 0);
        } else {
            mul_add_kernel<scalar_t, unroll, true>
                <<<grid, threadsPerBlock, 0, stream>>>(x.data_ptr<scalar_t>(),
                                                       nullptr,
                                                       bias.data_ptr<scalar_t>(),
                                                       (scalar_t)scale_shift,
                                                       numel,
                                                       1,
                                                       numel_bias,
                                                       x.stride(0),
                                                       0,
                                                       batch_bias ? bias.stride(0) : 0);
        }
    });
}

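// Embedding lookup: for each int32 id in input_id, copies the corresponding
// row of the [vocab, dim] lookup table into the output, whose shape is
// input_id.shape + [dim]. One block per id; for dim > 1024 the threads
// presumably loop over the row.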
Tensor embedding(Tensor input_id, Tensor lookup) {
    assert(input_id.dtype() == Tensor::INT32);
    assert(lookup.ndims() == 2);

    auto shapeOut = input_id.shape;
    shapeOut.dataExtent.push_back(lookup.shape[-1]);

    auto stream = getCurrentCUDAStream();

    Tensor out = Tensor::empty(shapeOut, lookup.scalar_type(), input_id.device());

    dispatch(out.scalar_type(), [&]<typename scalar_t>() {
        EmbeddingKernel<<<input_id.numel(), std::min(lookup.shape[-1], 1024), 0, stream>>>(
            input_id.data_ptr<int32_t>(), out.data_ptr<scalar_t>(), lookup.data_ptr<scalar_t>(), lookup.shape[-1]);
    });

    return out;
}

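// Greedy (argmax) sampling: for each row of the [batch, vocab] logits,
// returns the int32 index of the maximum element. One block per row.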
Tensor argmax_sample(Tensor logits) {
    assert(logits.ndims() == 2);

    auto stream = getCurrentCUDAStream();

    Tensor out = Tensor::empty({logits.shape[0]}, Tensor::INT32, logits.device());

    dispatch(logits.scalar_type(), [&]<typename scalar_t>() {
        argmax_sample_kernel<<<logits.shape[0], std::min(logits.shape[1], 1024), 0, stream>>>(
            logits.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), logits.shape[1]);
    });

    return out;
}

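// Splits a packed per-token QKV tensor of last-dim width dim_q + dim_k +
// dim_v into the preallocated q, k, v tensors; one block per token.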
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v) {
    // FIXME FIXME
    // assert(qkv.shape[0] == q.shape[0]);
    // assert(qkv.shape[0] == k.shape[0]);
    // assert(qkv.shape[0] == v.shape[0]);

    auto stream = getCurrentCUDAStream();

    int dim_q = q.shape[-1] * q.shape[-2];
    int dim_k = k.shape[-1] * k.shape[-2];
    int dim_v = v.shape[-1] * v.shape[-2];

    assert(dim_k == dim_v);
    assert(dim_q + dim_k + dim_v == qkv.shape[-1]);

    int num_tokens = qkv.numel() / qkv.shape[-1];

    dispatch(qkv.scalar_type(), [&]<typename scalar_t>() {
        splitqkv_kernel<<<num_tokens, std::min(qkv.shape[-1], 1024), 0, stream>>>(qkv.data_ptr<scalar_t>(),
                                                                                  q.data_ptr<scalar_t>(),
                                                                                  k.data_ptr<scalar_t>(),
                                                                                  v.data_ptr<scalar_t>(),
                                                                                  dim_q,
                                                                                  dim_k);
    });
}

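// Splits the last dimension of `input` into N equal chunks, returning N
// freshly allocated tensors; the name suggests it unpacks fused modulation
// vectors (e.g. shift/scale/gate). One thread per input element.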
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input) {
    assert(input.shape[-1] % N == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (input.numel() + threadsPerBlock - 1) / threadsPerBlock;

    auto stream = getCurrentCUDAStream();

    auto shapeOut = TensorShape(input.shape.dataExtent);
    shapeOut[-1] /= N;

    std::array<Tensor, N> out;
    for (int k = 0; k < N; k++) {
        out[k] = Tensor::empty(shapeOut, input.scalar_type(), input.device());
    }

    dispatch(input.scalar_type(), [&]<typename scalar_t>() {
        std::array<scalar_t *, N> outPtr;
        for (int k = 0; k < N; k++) {
            outPtr[k] = out[k].template data_ptr<scalar_t>();
        }
        split_mod_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
            input.data_ptr<scalar_t>(), outPtr, input.numel());
    });

    return out;
}

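// Quantizes x to int8 using a single static scale for the whole tensor; the
// exact rounding/clamping behavior lives in quant_kernel_static in
// misc_kernels_impl.cuh. Each thread handles `unroll` elements.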
Tensor quant_static(Tensor x, float scale) {
    Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());

    constexpr int unroll = 8;

    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentCUDAStream();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        quant_kernel_static<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
            x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
    });

    return out;
}

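// Same as quant_static, but presumably applies GELU to x before quantizing,
// fused into a single kernel pass.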
Tensor quant_static_fuse_gelu(Tensor x, float scale) {
    Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());

    constexpr int unroll = 8;

    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);

    int threadsPerBlock = 1024;
    int blocksPerGrid   = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);

    auto stream = getCurrentCUDAStream();

    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
        quant_kernel_static_fuse_gelu<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
            x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
    });

    return out;
}

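// Casts `input` into the preallocated `output` of the same shape. Aliasing
// input and output is only allowed when their element sizes match (an
// in-place reinterpreting cast). `unroll` is sized so each thread moves
// 16 bytes of the wider of the two element types.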
void cast(Tensor input, Tensor output) {
    assert(input.is_contiguous());
    assert(output.is_contiguous());
    assert(input.shape.dataExtent == output.shape.dataExtent);

    if (input.data_ptr() == output.data_ptr()) {
        assert(input.scalar_size() == output.scalar_size());
    }

    auto stream = getCurrentCUDAStream();

    dispatch(input.scalar_type(), [&]<typename input_t>() {
        dispatch(output.scalar_type(), [&]<typename output_t>() {
            constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t));

            int threadsPerBlock = 1024;
            int blocksPerGrid   = (int)ceilDiv<int64_t>(input.numel(), threadsPerBlock * unroll);

            cast_kernel<input_t, output_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
                input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel());

            checkCUDA(cudaGetLastError());
        });
    });
}

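// Returns the int32 indices of the top-k elements along the last dimension
// (output shape [..., k]). k is dispatched to the compile-time constant K,
// bounded by MAXK; with 32 threads per block, each thread presumably scans
// one row of the batch.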
Tensor topk(Tensor x, int k) {
    constexpr int MAXK = 64 + 4;

    const int N     = x.shape[-1];
    const int batch = x.numel() / N;

    assert(k <= N);
    assert(k <= MAXK);

    auto outShape = TensorShape(x.shape.dataExtent);
    outShape[-1]  = k;
    outShape.dataStride.clear();

    Tensor out = Tensor::empty(outShape, Tensor::INT32, x.device());

    auto stream = getCurrentCUDAStream();

    dispatchVal(k, std::make_integer_sequence<int, MAXK + 1>(), [&]<int K>() {
        if constexpr (K == 0) {
            assert(false);
            return;
        }
        if constexpr (K > 0) {
            dispatch(x.scalar_type(), [&]<typename scalar_t>() {
                topk_kernel<scalar_t, K><<<ceilDiv(batch, 32), 32, 0, stream>>>(
                    x.data_ptr<scalar_t>(), out.data_ptr<int>(), N, x.stride(-2), batch);
                checkCUDA(cudaGetLastError());
            });
        }
    });

    return out;
}

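// Explicit instantiations for the split factors used by callers.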
template std::array<Tensor, 2> split_mod<2>(Tensor input);
template std::array<Tensor, 3> split_mod<3>(Tensor input);
template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input);

} // namespace nunchaku::kernels