moe_compute_kernel.cu 12.4 KB
Newer Older
1
2
#include "moe_cuda_kernel.h"

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
3
4
5
6
#include <cstdio>
#include <iostream>
#include <vector>

Jiezhong Qiu's avatar
updarte  
Jiezhong Qiu committed
7
8
#include <cuda.h>
#include <cuda_runtime.h>
Rick Ho's avatar
Rick Ho committed
9
#include <cublas_v2.h>
Jiezhong Qiu's avatar
Jiezhong Qiu committed
10
#include <c10/cuda/CUDAGuard.h>
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
11

Rick Ho's avatar
Rick Ho committed
12
#include "timer.hh"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
13

Rick Ho's avatar
Rick Ho committed
14
15
#include "cublas_wrapper.h"
#include "cuda_stream_manager.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
16

Rick Ho's avatar
Rick Ho committed
17
// Integer ceiling division: CEIL(x, y) == ceil(x / y) for x >= 1, y >= 1.
// Only meaningful for x >= 1: for a signed x == 0 and y > 1 it yields 1
// (not 0), and for an unsigned x == 0 the (x - 1) term wraps around.
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
18

Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
19
20
/*
 * Builds a pointer table: ptrs[k] = base + stride * offset[k] for k in [0, n).
 *
 * Launch layout: 1-D grid with at least n threads in total; threads whose
 * flat index is >= n do nothing.
 */
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
        const long* offset, const scalar_t** ptrs) {
    const size_t slot = blockIdx.x * blockDim.x + threadIdx.x;
    if (slot >= n) {
        return;
    }
    ptrs[slot] = base + offset[slot] * stride;
}

29
30
/*
 * Row gather into sorted order: row b of oubuf = row pos[b] of inbuf, each
 * row being wid contiguous elements.
 *
 * Launch layout: one block per output row (gridDim.x == number of rows);
 * any blockDim.x, the block's threads stride across the row.
 */
template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos, 
        const scalar_t* inbuf, scalar_t* oubuf) { 
    const scalar_t* src_row = inbuf + wid * pos[blockIdx.x];
    scalar_t* dst_row = oubuf + wid * blockIdx.x;
    for (int col = threadIdx.x; col < wid; col += blockDim.x) {
        dst_row[col] = src_row[col];
    }
}

40
41

/*
    Column-wise sum of a row-major (m x n) matrix:
        result[c] = sum_{r < m} matrix[r * n + c]

    Launch layout: one block per column (gridDim.x == n), any blockDim.x up to
    the CUDA limit of 1024 threads per block. Each thread first sums rows
    tid, tid + blockDim.x, ... of its block's column; the block then
    tree-reduces those partial sums in shared memory.

    The shared buffer is static and sized for 1024 threads, so no dynamic
    shared memory is required (any passed at launch is simply unused).
    Partial sums are kept in column_reduce_acc<scalar_t>::type: double for
    double inputs, float for everything else (including half).
*/
template <typename T>
struct column_reduce_acc { using type = float; };
template <>
struct column_reduce_acc<double> { using type = double; };

template <typename scalar_t>
__global__ 
void column_reduce(const scalar_t * matrix, scalar_t * result, 
    int m /* lines */, int n /* columns*/) {
    using acc_t = typename column_reduce_acc<scalar_t>::type;
    // Static rather than `extern __shared__`: each template instantiation
    // gets its own correctly-typed buffer, independent of the launch's
    // dynamic shared-memory size. 1024 = max threads per block.
    __shared__ acc_t sdata[1024];

    const unsigned int tid = threadIdx.x;
    // size_t so row * cols cannot overflow 32 bits for large matrices.
    const size_t col = blockIdx.x;
    const size_t rows = static_cast<size_t>(m);
    const size_t cols = static_cast<size_t>(n);

    // Phase 1: strided walk down this block's column.
    acc_t partial = 0;
    for (size_t row = tid; row < rows; row += blockDim.x) {
        partial += static_cast<acc_t>(matrix[row * cols + col]);
    }
    sdata[tid] = partial;
    __syncthreads();

    // Phase 2: interleaved tree reduction. The loop bound depends only on
    // blockDim.x, so every thread reaches every __syncthreads(); the
    // tid + s guard keeps non-power-of-two block sizes in bounds.
    for (unsigned int s = 1; s < blockDim.x; s *= 2) {
        if (tid % (2 * s) == 0 && tid + s < blockDim.x) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        result[blockIdx.x] = static_cast<scalar_t>(sdata[0]);
    }
}

Rick Ho's avatar
Rick Ho committed
74
/*
 * Counts tokens per expert and computes each token's slot in expert-sorted
 * order, on the host.
 *
 * d_gate        device array [batch_size]: expert index of each token. Values
 *               are assumed to lie in [0, num_expert); this is NOT checked.
 * expert_count  host array [num_expert], written: tokens routed to each expert
 *               (cleared with memset, so it must be host memory).
 * d_pos         device array [batch_size], written: pos[i] is token i's index
 *               once tokens are grouped by expert, keeping input order within
 *               each expert.
 *
 * Uses synchronous cudaMemcpy in both directions.
 */
void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    memset(expert_count, 0, sizeof(int) * num_expert);
    if (batch_size == 0) {
        return;
    }

    // std::vector owns every host scratch buffer, so nothing leaks even if
    // checkCudaErrors bails out.
    std::vector<int> gate(batch_size);
    checkCudaErrors(cudaMemcpy(gate.data(), d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));

    for (size_t i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }

    // Exclusive prefix sum: first sorted slot owned by each expert
    // (well-defined for num_expert == 0 as well).
    std::vector<int> expert_ptr(num_expert);
    int running = 0;
    for (size_t e = 0; e < num_expert; ++e) {
        expert_ptr[e] = running;
        running += expert_count[e];
    }

    std::vector<int> pos(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
    }
    checkCudaErrors(cudaMemcpy(d_pos, pos.data(), sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
}
109

Rick Ho's avatar
Rick Ho committed
110
111
112
/*
 * Copies rows of `input` into `input_buf` in expert-sorted order:
 * input_buf[r] = input[d_pos[r]] for r in [0, batch_size), rows in_feat wide.
 *
 * Runs batch_scatter_kernel on smgr->stream(0), then smgr->sync(1)
 * (presumably waits for stream 0 — see CudaStreamManager).
 */
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const long batch_size,
        const long in_feat, 
        CudaStreamManager* smgr) {
    // A zero-block grid is an invalid launch configuration (it would leave a
    // pending cudaErrorInvalidConfiguration), and there is nothing to copy.
    if (batch_size > 0) {
        batch_scatter_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
                    input_buf);
    }
    smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
123

Rick Ho's avatar
Rick Ho committed
124
125
/*
 * Row scatter back to original order: row pos[b] of oubuf = row b of inbuf,
 * each row being wid contiguous elements (the inverse of batch_scatter_kernel
 * for the same pos).
 *
 * Launch layout: one block per input row; any blockDim.x, the block's threads
 * stride across the row.
 */
template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos, 
        const scalar_t* inbuf, scalar_t* oubuf) { 
    const scalar_t* src_row = inbuf + wid * blockIdx.x;
    scalar_t* dst_row = oubuf + wid * pos[blockIdx.x];
    for (int col = threadIdx.x; col < wid; col += blockDim.x) {
        dst_row[col] = src_row[col];
    }
}

/*
 * Copies expert-sorted rows back to original order:
 * output[d_pos[r]] = output_buf[r] for r in [0, batch_size), rows out_feat wide.
 *
 * Runs batch_gather_kernel on smgr->stream(0), then smgr->sync(1)
 * (presumably waits for stream 0 — see CudaStreamManager).
 */
template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    // A zero-block grid is an invalid launch configuration (it would leave a
    // pending cudaErrorInvalidConfiguration), and there is nothing to copy.
    if (batch_size > 0) {
        batch_gather_kernel<scalar_t>
            <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
                    output);
    }
    smgr->sync(1);
}
Rick Ho's avatar
Rick Ho committed
148

Rick Ho's avatar
Rick Ho committed
149
150
151
152
/*
 * Forward pass of the per-expert linear layer over expert-sorted rows.
 *
 * Expert e owns expert_count[e] consecutive rows of input_buf/output_buf.
 * For every non-empty expert: output rows = input rows @ W_e^T, where
 * W_e = weight + e * out_feat * in_feat is row-major [out_feat x in_feat].
 * When has_bias is true, output_buf is expected to already hold the bias and
 * the product is accumulated onto it (beta = 1); otherwise it is overwritten.
 *
 * expert_count is dereferenced on the host. Each expert's GEMM uses
 * smgr->handle(e); the function ends with smgr->sync(num_expert).
 */
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* output_buf,
        const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1;
    scalar_t beta = has_bias ? 1 : 0;

    const size_t weight_stride = in_feat * out_feat;
    int row_offset = 0;
    for (int e = 0; e < num_expert; ++e) {
        const long rows = expert_count[e];
        if (rows != 0) {
            // cuBLAS is column-major: computing T(B) x T(A) = T(C) yields
            // a row-major C.
            checkCudaErrors(cublasXgemm(
                    smgr->handle(e),
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    out_feat, rows, in_feat,
                    &alpha,
                    weight + e * weight_stride, in_feat,
                    input_buf + row_offset * in_feat, in_feat,
                    &beta,
                    output_buf + out_feat * row_offset, out_feat
                    ));
            row_offset += rows;
        }
    }
    smgr->sync(num_expert);
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
184
/*
 * Backward pass of the per-expert linear layer over expert-sorted rows.
 *
 * Expert i owns expert_count[i] consecutive rows starting at the running row
 * offset `ptr`; W_i = weight + i * in_feat * out_feat is row-major
 * [out_feat x in_feat]. For every non-empty expert:
 *   grad_input rows [count x in_feat]   = grad_output rows @ W_i       (handle i)
 *   grad_weight[i]  [out_feat x in_feat] = grad_output rows^T @ input rows (handle i)
 *   grad_bias[i]    [out_feat]           = column sums of grad_output rows
 *                                          (column_reduce on stream 0, only if has_bias)
 * Experts with no rows get grad_weight[i] and grad_bias[i] zeroed. When
 * has_bias is false, grad_bias entries of non-empty experts are not written.
 *
 * expert_count is dereferenced on the host. batch_size is unused here.
 * Ends with smgr->sync(num_expert), which presumably waits for streams
 * [0, num_expert) — see CudaStreamManager.
 */
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        scalar_t* grad_bias,
        const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    // column_reduce launch configuration. The dynamic shared-memory size
    // covers a 1024-entry buffer of float (the accumulator used for
    // float/half) or of scalar_t, whichever is larger; sizeof(scalar_t)
    // alone is too small for half.
    const unsigned int reduce_threads = 1024;
    const size_t reduce_smem = reduce_threads *
        (sizeof(scalar_t) > sizeof(float) ? sizeof(scalar_t) : sizeof(float));

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            // No rows for this expert: its gradients are zero. Issued on
            // stream 0 (like column_reduce) so the final smgr->sync also
            // orders them, and checked like every other CUDA call here.
            checkCudaErrors(cudaMemsetAsync(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat, smgr->stream(0)));
            checkCudaErrors(cudaMemsetAsync(grad_bias + i * out_feat, 0,
                    sizeof(scalar_t) * out_feat, smgr->stream(0)));
            continue;
        }
        // cuBLAS is column-major: computing T(B) x T(A) = T(C) yields a
        // row-major C.

        // grad_input rows = grad_output rows @ W_i
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
                ));

        // grad_weight[i] = grad_output rows^T @ input rows
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
                ));

        if (has_bias) {
            // grad_bias[i] = column sums of this expert's grad_output rows.
            column_reduce<scalar_t>
            <<<out_feat, reduce_threads, reduce_smem, smgr->stream(0)>>>
            (
                grad_output_buf + ptr * out_feat,
                grad_bias + i * out_feat,
                expert_count[i],
                out_feat
            );
            // Report launch-configuration errors here instead of at some
            // unrelated later CUDA call.
            checkCudaErrors(cudaGetLastError());
        }

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
Rick Ho's avatar
Rick Ho committed
251
252


Rick Ho's avatar
Rick Ho committed
253
/*
 * Python-facing wrapper around moe_cuda_expert_count_impl.
 *
 * gate:       int32 tensor [batch_size] on the device, expert index per token.
 * num_expert: number of experts.
 * Returns {expert_count, pos}:
 *   expert_count  int32 [num_expert], allocated with default options (CPU),
 *   pos           int32 [batch_size] on gate's device.
 */
std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate, 
        size_t num_expert) {
    const auto batch_size = gate.size(0);

    auto expert_count = torch::empty(num_expert,
            torch::TensorOptions().dtype(torch::kInt32));
    auto pos = torch::empty(batch_size,
            torch::TensorOptions().device(gate.device()).dtype(torch::kInt32));

    int* gate_ptr = gate.data_ptr<int>();
    int* count_ptr = expert_count.data_ptr<int>();
    int* pos_ptr = pos.data_ptr<int>();
    moe_cuda_expert_count_impl(gate_ptr, count_ptr, pos_ptr,
            num_expert, batch_size);

    return {expert_count, pos};
}

Rick Ho's avatar
Rick Ho committed
275
276
/*
 * Python-facing wrapper: reorders the rows of `input` ([N x in_feat]) into a
 * new tensor input_buf[r] = input[pos[r]] of shape [pos.size(0) x in_feat],
 * with input's dtype and device. pos is read as int64.
 */
std::vector<torch::Tensor> moe_cuda_local_scatter(
    torch::Tensor input,
    torch::Tensor pos) {
    auto smgr = getCudaStreamManager(input.device().index());
    const auto batch_size = pos.size(0);
    const auto in_feat = input.size(1);

    auto input_buf = torch::empty({batch_size, in_feat},
            torch::TensorOptions().dtype(input.dtype()).device(input.device()));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda", 
            ([&] {
        const scalar_t* src = input.data_ptr<scalar_t>();
        const long* order = pos.data_ptr<long>();
        scalar_t* dst = input_buf.data_ptr<scalar_t>();
        moe_cuda_local_scatter_impl<scalar_t>(src, order, dst,
                batch_size, in_feat, smgr);
    }));
    return {input_buf,};
}
Jiezhong Qiu's avatar
Jiezhong Qiu committed
299

Rick Ho's avatar
Rick Ho committed
300
/*
 * Python-facing wrapper: moves expert-sorted rows back to original order,
 * output[pos[r]] = output_buf[r], producing a [pos.size(0) x out_feat] tensor
 * with output_buf's dtype and device. pos is read as int64.
 */
std::vector<torch::Tensor> moe_cuda_local_gather(
    torch::Tensor output_buf,
    torch::Tensor pos) {
    auto smgr = getCudaStreamManager(output_buf.device().index());
    const auto batch_size = pos.size(0);
    const auto out_feat = output_buf.size(1);

    auto output = torch::empty({batch_size, out_feat},
            torch::TensorOptions().dtype(output_buf.dtype()).device(output_buf.device()));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda", 
            ([&] {
        const scalar_t* src = output_buf.data_ptr<scalar_t>();
        const long* order = pos.data_ptr<long>();
        scalar_t* dst = output.data_ptr<scalar_t>();
        moe_cuda_local_gather_impl<scalar_t>(src, order, dst,
                batch_size, out_feat, smgr);
    }));
    return {output,};
}
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
324

Jiezhong Qiu's avatar
Jiezhong Qiu committed
325
/*
 * Python-facing forward pass of the per-expert linear layer.
 *
 * input_buf:    expert-sorted rows, [batch_size x in_feat].
 * expert_count: rows per expert; read as int64 on the host by
 *               moe_cuda_forward_impl (presumably a CPU tensor — verify
 *               with the caller).
 * weight:       [num_expert x out_feat x in_feat].
 * bias:         optional; its rows are repeated per expert via
 *               repeat_interleave(expert_count) to seed the output
 *               (presumably [num_expert x out_feat]).
 * Returns {output}, [batch_size x out_feat].
 */
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
        torch::Tensor expert_count,
        torch::Tensor weight,
        at::optional<torch::Tensor> bias
        ) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", 
            num_expert, in_feat, out_feat);
#endif

    const bool has_bias = bias.has_value();

    // With a bias, each row starts as its expert's bias and the GEMM
    // accumulates onto it; without one, the output starts uninitialized and
    // the GEMM (called with has_bias = false) overwrites it.
    torch::Tensor output = has_bias
        ? bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0)
        : torch::empty({batch_size, out_feat},
                torch::TensorOptions()
                    .device(input_buf.device())
                    .dtype(input_buf.dtype()));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda", 
            ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            has_bias,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {output, };
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
371
/*
 * Python-facing backward pass of the per-expert linear layer.
 *
 * Returns {grad_input_buf [batch_size x in_feat],
 *          grad_weight    [num_expert x out_feat x in_feat],
 *          grad_bias      [num_expert x out_feat]}.
 * grad_bias is always allocated and returned. When `bias` is absent,
 * moe_cuda_backward_impl only zeroes the entries of experts with no rows and
 * leaves the rest uninitialized, so callers should ignore it in that case.
 */
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf,    // [batch_size x out_feat]
    torch::Tensor input_buf,          // [batch_size x in_feat]
    torch::Tensor expert_count,       // rows per expert; read as int64 on the host by the impl
    torch::Tensor weight,             // [num_expert x out_feat x in_feat]
    at::optional<torch::Tensor> bias  // only its presence is used here
) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    // Outputs share grad_output_buf's dtype and device.
    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); 
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
            grad_bias.data_ptr<scalar_t>(),
            bias.has_value(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));

    return {grad_input_buf, grad_weight, grad_bias};
}